vllm.model_executor.layers.quantization.modelopt

KV_CACHE_QUANT_ALGOS module-attribute

KV_CACHE_QUANT_ALGOS = ['FP8']

QUANT_ALGOS module-attribute

QUANT_ALGOS = ['FP8', 'NVFP4']

logger module-attribute

logger = init_logger(__name__)

ModelOptFp8Config

Bases: QuantizationConfig

Config class for ModelOpt FP8.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: Optional[str] = None,
        exclude_modules: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules or []
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt config should be used based on
        quantization config."""

        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if "FP8" in quant_algo:
                    return "modelopt"
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP8" in quant_algo:
                return "modelopt"

        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and substring matching.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Then check substring matching for patterns not caught by exact match
        for module in self.exclude_modules:
            # Skip exact matches already handled above
            if module != prefix and (
                module in prefix
                or (
                    prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model.")
                )
            ):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None

exclude_modules instance-attribute

exclude_modules = exclude_modules or []

is_checkpoint_fp8_serialized instance-attribute

is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

kv_cache_quant_method instance-attribute

kv_cache_quant_method = kv_cache_quant_method

__init__

__init__(
    is_checkpoint_fp8_serialized: bool = False,
    kv_cache_quant_method: Optional[str] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    kv_cache_quant_method: Optional[str] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None:
    super().__init__()
    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
    self.kv_cache_quant_method = kv_cache_quant_method
    self.exclude_modules = exclude_modules or []
    if is_checkpoint_fp8_serialized:
        logger.warning(
            "Detected ModelOpt fp8 checkpoint. Please note that"
            " the format is experimental and could change."
        )

apply_vllm_mapper

apply_vllm_mapper(hf_to_vllm_mapper: WeightsMapper)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
    if self.exclude_modules is not None:
        self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

from_config classmethod

from_config(config: dict[str, Any]) -> ModelOptFp8Config
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
    # Handle both ModelOpt format and compressed-tensors style format
    if "quantization" in config:
        # ModelOpt format: {"quantization": {"quant_algo": "..."}}
        quant_config = cls.get_from_keys(config, ["quantization"])
        if not isinstance(quant_config, dict):
            raise ValueError("Expected 'quantization' to be a dictionary in config")
        quant_method = quant_config.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")
        kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
        # "exclude_modules" is the key in the legacy hf_quant_config.json
        exclude_modules = quant_config.get("exclude_modules")
    else:
        # Compressed-tensors style format:
        # {"quant_algo": "...", "quant_method": "modelopt"}
        quant_method = config.get("quant_algo", "")
        kv_cache_quant_method = config.get("kv_cache_quant_algo")
        # "ignore" is the key in config.json
        exclude_modules = config.get("ignore")

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration."
        )
    is_checkpoint_fp8_serialized = "FP8" in quant_method

    return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
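
For reference, a minimal sketch of the two config layouts this method accepts; the field values are illustrative, not taken from a real checkpoint.

from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

# Legacy hf_quant_config.json layout: everything nested under "quantization".
legacy_cfg = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    }
}

# Flat, compressed-tensors style layout found in config.json.
flat_cfg = {
    "quant_method": "modelopt",
    "quant_algo": "FP8",
    "kv_cache_quant_algo": "FP8",
    "ignore": ["lm_head"],
}

cfg_a = ModelOptFp8Config.from_config(legacy_cfg)
cfg_b = ModelOptFp8Config.from_config(flat_cfg)
assert cfg_a.is_checkpoint_fp8_serialized and cfg_b.is_checkpoint_fp8_serialized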

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return ["hf_quant_config.json"]

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_min_capability(cls) -> int:
    return 89

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "modelopt"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
    from vllm.attention.layer import Attention  # Avoid circular import

    if isinstance(layer, LinearBase):
        if self.is_layer_excluded(prefix):
            return UnquantizedLinearMethod()
        # Check if this is a vision model layer that should not be quantized
        if "vision_tower" in prefix or "vision_model" in prefix:
            return UnquantizedLinearMethod()
        return ModelOptFp8LinearMethod(self)
    elif isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)
    elif isinstance(layer, FusedMoE):
        return ModelOptFp8MoEMethod(self, layer)
    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half]

is_layer_excluded

is_layer_excluded(prefix: str) -> bool

Check if a layer should be excluded from quantization. Handles both exact matching (for fused layers) and substring matching.

This method handles both regular models and multimodal models that use the language_model prefix. For multimodal models, it checks if the module name (without the language_model prefix) is in the exclude list.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def is_layer_excluded(self, prefix: str) -> bool:
    """
    Check if a layer should be excluded from quantization.
    Handles both exact matching (for fused layers) and substring matching.

    This method handles both regular models and multimodal models that use
    the language_model prefix. For multimodal models, it checks if the
    module name (without the language_model prefix) is in the exclude list.
    """
    if self.exclude_modules is None:
        return False

    # First check exact matching with fused layer support
    if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
        return True

    # Then check substring matching for patterns not caught by exact match
    for module in self.exclude_modules:
        # Skip exact matches already handled above
        if module != prefix and (
            module in prefix
            or (
                prefix.startswith("language_model.")
                and module in prefix.removeprefix("language_model.")
            )
        ):
            return True
    return False
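
A simplified, standalone illustration of the substring branch above (not part of the vLLM API), including the multimodal language_model prefix handling; the module names are made up.

def _substring_excluded(prefix: str, exclude_modules: list[str]) -> bool:
    # Mirrors the loop above: plain substring match, plus a retry with the
    # "language_model." prefix stripped for multimodal checkpoints.
    for module in exclude_modules:
        if module != prefix and (
            module in prefix
            or (
                prefix.startswith("language_model.")
                and module in prefix.removeprefix("language_model.")
            )
        ):
            return True
    return False

assert _substring_excluded(
    "language_model.model.layers.0.mlp.gate_up_proj", ["model.layers.0.mlp"]
)
assert not _substring_excluded("model.layers.1.mlp.down_proj", ["lm_head"])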

override_quantization_method classmethod

override_quantization_method(
    hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]

Detect if this ModelOpt config should be used based on quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def override_quantization_method(
    cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
    """Detect if this ModelOpt config should be used based on
    quantization config."""

    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'
    quant_method = hf_quant_cfg.get("quant_method", "").lower()

    # Only proceed if the method is explicitly "modelopt"
    if quant_method != "modelopt":
        return None

    # Look for ModelOpt-specific config structure
    if "quantization" in hf_quant_cfg:
        quant_config = hf_quant_cfg["quantization"]
        if isinstance(quant_config, dict):
            quant_algo = quant_config.get("quant_algo", "")
            if "FP8" in quant_algo:
                return "modelopt"
    else:
        # Check for compressed-tensors style config with specific quant_algo
        quant_algo = hf_quant_cfg.get("quant_algo", "")
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"

    return None
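
Concretely, both of the following shapes resolve to "modelopt"; the field values are examples only.

from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

nested = {"quant_method": "modelopt", "quantization": {"quant_algo": "FP8"}}
flat = {"quant_method": "modelopt", "quant_algo": "FP8"}

assert ModelOptFp8Config.override_quantization_method(nested, None) == "modelopt"
assert ModelOptFp8Config.override_quantization_method(flat, None) == "modelopt"
# A different quant_method, or a quant_algo without "FP8", yields None so that
# other quantization configs can claim the checkpoint.
assert ModelOptFp8Config.override_quantization_method({"quant_method": "awq"}, None) is None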

ModelOptFp8KVCacheMethod

Bases: BaseKVCacheMethod

Supports loading kv-cache scaling factors from FP8 checkpoints.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config]):
        super().__init__(quant_config)

__init__

__init__(
    quant_config: Union[
        ModelOptFp8Config, ModelOptNvFp4Config
    ],
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config]):
    super().__init__(quant_config)

ModelOptFp8LinearMethod

Bases: LinearMethodBase

Linear method for Model Optimizer static quantization. Supports loading FP8 checkpoints with static weight and activation scales. Future support might be added for dynamic scales.

Limitations:

1. Only per-tensor quantization is supported, due to torch._scaled_mm support.
2. Only the float8_e4m3fn datatype is supported.

Args: quant_config: The ModelOpt quantization config.
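
Per-tensor here means a single float32 scale covers the whole weight tensor. A standalone sketch of that quant/dequant round trip (illustrative only; the real path runs through Fp8LinearOp):

import torch

# One float32 scale for the entire tensor, chosen so the largest magnitude
# maps onto the FP8 E4M3 maximum representable value.
w = torch.randn(1024, 1024, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
weight_scale = w.abs().max().float() / fp8_max
w_fp8 = (w.float() / weight_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
w_dequant = w_fp8.float() * weight_scale   # approximate reconstruction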

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype
        Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp(
    act_quant_static=True, act_quant_group_shape=PER_TENSOR
)

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: ModelOptFp8Config) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptFp8Config) -> None:
    self.quant_config = quant_config
    self.fp8_linear = Fp8LinearOp(
        act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
    )

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return self.fp8_linear.apply(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        input_scale=layer.input_scale,
        bias=bias,
    )

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    del input_size, output_size
    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    weight_dtype = (
        torch.float8_e4m3fn
        if self.quant_config.is_checkpoint_fp8_serialized
        else params_dtype
    )
    weight = ModelWeightParameter(
        data=torch.empty(
            output_size_per_partition, input_size_per_partition, dtype=weight_dtype
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight", weight)

    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALE
        weight_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)
        # INPUT SCALE
        scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

        scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("input_scale", scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: Module) -> None:
    weight = layer.weight
    max_w_scale = layer.weight_scale.max()
    if not (layer.weight_scale == layer.weight_scale[0]).all():
        max_w_scale, weight = requantize_with_max_scale(
            layer.weight, layer.weight_scale, layer.logical_widths
        )
    layer.weight = Parameter(weight.t(), requires_grad=False)
    layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
    layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
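
A standalone sketch (not the vLLM requantize_with_max_scale helper) of what happens when fused shards were checkpointed with different per-tensor scales: each shard is dequantized with its own scale and requantized with the shared maximum, so the fused weight ends up with a single weight_scale.

import torch

fp8 = torch.float8_e4m3fn
scales = torch.tensor([0.01, 0.02])                      # per-shard scales
shards_fp8 = [torch.randn(2048, 512).clamp(-4, 4).mul(100).to(fp8)
              for _ in scales]                           # checkpoint shards

max_scale = scales.max()
shards_requant = [
    (shard.float() * s / max_scale).to(fp8)              # dequant, then requant
    for shard, s in zip(shards_fp8, scales)
]
# shards_requant can now be concatenated and used with max_scale alone.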

ModelOptFp8MoEMethod

Bases: FusedMoEMethodBase

MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight and activation scales.

Args: quant_config: The ModelOpt quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale.
    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpt
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if layer.w13_weight_scale.dim() == 2:
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale

                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            register_moe_scaling_factors(layer)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> Optional[FusedMoEQuantConfig]:
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert self.fused_experts is None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # cutlass or fused_experts.
        #
        if self.fused_experts is not None:
            return self.fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not renormalize
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

cutlass_fp8_supported instance-attribute

cutlass_fp8_supported = cutlass_fp8_supported()

flashinfer_moe_backend instance-attribute

flashinfer_moe_backend: Optional[FlashinferMoeBackend] = (
    None
)

layer instance-attribute

layer = layer

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(
    quant_config: ModelOptFp8Config, layer: Module
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    quant_config: ModelOptFp8Config,
    layer: torch.nn.Module,
) -> None:
    super().__init__(layer.moe_config)
    self.layer = layer
    self.quant_config = quant_config
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        cutlass_fp8_supported,
    )

    self.cutlass_fp8_supported = cutlass_fp8_supported()
    self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
    if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
        )
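
The FlashInfer path is opt-in; assuming flashinfer is installed, the gate is the vLLM environment variable shown above, which must be set before the engine is initialized (exact parsing lives in vllm.envs):

import os

os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"  # enable FlashInfer FP8 MoE kernels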

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `ModelOptFp8MoEMethod` yet."
        )

    if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
        assert self.fused_experts is None
        assert activation == "silu", (
            f"Expected 'silu' activation but got {activation}"
        )
        assert not renormalize
        return apply_flashinfer_per_tensor_scale_fp8(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
            top_k=top_k,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    # Expert selection
    topk_weights, topk_ids, _ = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
    )

    #
    # Note: the order here is important. self.fused_experts can override
    # cutlass or fused_experts.
    #
    if self.fused_experts is not None:
        return self.fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
        assert not renormalize
        assert activation == "silu", (
            f"Expected 'silu' activation but got {activation}"
        )
        return flashinfer_cutlass_moe_fp8(
            x,
            layer,
            topk_weights,
            topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    else:
        from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

        assert self.moe_quant_config is not None

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            quant_config=self.moe_quant_config,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    # Use FP8 dtype if checkpoint is serialized
    weight_dtype = (
        torch.float8_e4m3fn
        if self.quant_config.is_checkpoint_fp8_serialized
        else params_dtype
    )
    weight_loader = extra_weight_attrs.get("weight_loader")

    w13_weight = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=weight_dtype,
        ),
        input_dim=2,
        output_dim=1,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w13_weight", w13_weight)

    w2_weight = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=weight_dtype,
        ),
        input_dim=2,
        output_dim=1,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w2_weight", w2_weight)

    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, 2),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Set weight loader attributes for scales
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

get_fused_moe_quant_config

get_fused_moe_quant_config(
    layer: Module,
) -> Optional[FusedMoEQuantConfig]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
    if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
        return None

    return fp8_w8a8_moe_quant_config(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        per_act_token_quant=False,
    )

maybe_make_prepare_finalize

maybe_make_prepare_finalize() -> Optional[
    FusedMoEPrepareAndFinalize
]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def maybe_make_prepare_finalize(
    self,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    # TRT LLM not supported with all2all yet.
    if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
        return None
    elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
        prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            self.moe
        )
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize
    else:
        return super().maybe_make_prepare_finalize()

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Process FP8 MoE weights after loading from serialized checkpoint. Only supports pre-quantized checkpoints with FP8 weights and scales.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Process FP8 MoE weights after loading from serialized checkpoint.
    Only supports pre-quantized checkpoints with FP8 weights and scales.
    """

    layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
    layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    from vllm._custom_ops import scaled_fp8_quant
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        per_tensor_dequantize,
    )

    # Handle scale parameters
    if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max of the w1 and w3 scales
        # then dequant and requant each expert.
        if layer.w13_weight_scale.dim() == 2:
            # Get the maximum scale across w1 and w3 for each expert
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values

            # Requantize each expert's weights using the combined scale
            # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
            # where the first intermediate_size rows are w1, the next are w3
            intermediate_size = layer.w13_weight.shape[1] // 2
            for expert_id in range(layer.w13_weight.shape[0]):
                start = 0
                for shard_id in range(2):  # w1 and w3
                    # Dequantize using the original scale for this shard
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][
                            start : start + intermediate_size, :
                        ],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    # Requantize using the combined max scale

                    (
                        layer.w13_weight[expert_id][
                            start : start + intermediate_size, :
                        ],
                        _,
                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                    start += intermediate_size

            # Update the scale parameter to be per-expert
            layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
        else:
            layer.w13_weight_scale = Parameter(
                layer.w13_weight_scale.data, requires_grad=False
            )

    if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )
    # Input scales must be equal for each expert in fp8 MoE layers.
    if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
        layer.w13_input_scale = Parameter(
            layer.w13_input_scale.max(), requires_grad=False
        )
    if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
        layer.w2_input_scale = Parameter(
            layer.w2_input_scale.max(), requires_grad=False
        )

    if self.flashinfer_moe_backend is not None:
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        register_moe_scaling_factors(layer)
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
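
A standalone sketch of the w13 scale merge above (illustrative shapes, not the vLLM kernels): each expert's two shard scales collapse to their max, and each shard is dequantized with its old scale and requantized with the merged one.

import torch

E, I, H = 4, 2048, 1024
fp8 = torch.float8_e4m3fn
w13 = torch.randn(E, 2 * I, H).clamp(-4, 4).mul(100).to(fp8)
w13_scale = torch.rand(E, 2) * 0.02 + 0.001              # (num_experts, 2)

max_w13 = w13_scale.max(dim=1).values                    # one scale per expert
for e in range(E):
    for shard in range(2):                               # w1, then w3
        rows = slice(shard * I, (shard + 1) * I)
        dq = w13[e, rows].float() * w13_scale[e, shard]  # dequant with old scale
        w13[e, rows] = (dq / max_w13[e]).to(fp8)         # requant with merged scale
w13_scale = max_w13                                      # now shape (num_experts,)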

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: Module,
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/modelopt.py
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    assert self.moe_quant_config is not None
    experts = select_cutlass_fp8_gemm_impl(
        self.moe,
        self.moe_quant_config,
    )
    logger.debug_once("Using %s", experts.__class__.__name__)
    return experts

ModelOptNvFp4Config

Bases: QuantizationConfig

Config class for ModelOpt FP4.
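
Assuming an NVFP4 checkpoint, the hf_quant_config.json is expected to carry group_size, kv_cache_quant_algo and exclude_modules (see from_config below); a minimal illustrative example:

from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "group_size": 16,
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],   # example value
    }
}
nvfp4_cfg = ModelOptNvFp4Config.from_config(hf_quant_config)
assert nvfp4_cfg.is_checkpoint_nvfp4_serialized
assert nvfp4_cfg.group_size == 16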

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

            self.group_size = group_size
            self.kv_cache_quant_algo = kv_cache_quant_algo
            self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config."""
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        # Handle both traditional ModelOpt format and compressed-tensors
        # style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            # Handle kv_cache_quant_algo with proper type validation
            kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
            if kv_cache_quant_algo_raw is None:
                # No KV cache quantization by default
                kv_cache_quant_algo = None
            elif isinstance(kv_cache_quant_algo_raw, str):
                kv_cache_quant_algo = kv_cache_quant_algo_raw
            else:
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_algo_raw)}"
                )

            # Handle group_size with proper type validation
            group_size_raw = quant_config.get("group_size")
            if group_size_raw is None:
                group_size = 16  # Default value
            elif isinstance(group_size_raw, int):
                group_size = group_size_raw
            else:
                try:
                    group_size = int(group_size_raw)
                except (ValueError, TypeError):
                    raise ValueError(
                        f"group_size must be an integer, got {type(group_size_raw)}"
                    ) from None

            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
            if not isinstance(exclude_modules, list):
                raise ValueError(
                    f"exclude_modules must be a list, got {type(exclude_modules)}"
                )
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")

            # Handle kv_cache_quant_algo with proper type validation
            kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
            if kv_cache_quant_algo_raw is None:
                # No KV cache quantization by default
                kv_cache_quant_algo = None
            elif isinstance(kv_cache_quant_algo_raw, str):
                kv_cache_quant_algo = kv_cache_quant_algo_raw
            else:
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_algo_raw)}"
                )

            # Handle group_size with proper type validation
            group_size_raw = config.get("group_size")
            if group_size_raw is None:
                group_size = 16  # Default value
            elif isinstance(group_size_raw, int):
                group_size = group_size_raw
            else:
                try:
                    group_size = int(group_size_raw)
                except (ValueError, TypeError):
                    raise ValueError(
                        f"group_size must be an integer, got {type(group_size_raw)}"
                    ) from None

            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            if not isinstance(exclude_modules, list):
                raise ValueError(
                    f"exclude_modules must be a list, got {type(exclude_modules)}"
                )

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            # Check if required fields are present in the quantization config
            quant_config = config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            exclude_modules,
            group_size,
        )

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and pattern matching.
        """
        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Check regex pattern matching for patterns not caught by exact match
        import regex as re

        for pattern in self.exclude_modules:
            # Skip patterns that would be caught by exact matching
            if "*" in pattern or "." in pattern:
                regex_str = pattern.replace(".", r"\.").replace("*", r".*")
                if re.fullmatch(regex_str, prefix):
                    return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        skip_layer = self.is_layer_excluded(prefix)
        if isinstance(layer, LinearBase):
            if skip_layer:
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            if skip_layer:
                return None
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None

exclude_modules instance-attribute

exclude_modules = exclude_modules

group_size instance-attribute

group_size = group_size

is_checkpoint_nvfp4_serialized instance-attribute

is_checkpoint_nvfp4_serialized = (
    is_checkpoint_nvfp4_serialized
)

kv_cache_quant_algo instance-attribute

kv_cache_quant_algo = kv_cache_quant_algo

__init__

__init__(
    is_checkpoint_nvfp4_serialized: bool,
    kv_cache_quant_algo: Optional[str],
    exclude_modules: list[str],
    group_size: int = 16,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    is_checkpoint_nvfp4_serialized: bool,
    kv_cache_quant_algo: Optional[str],
    exclude_modules: list[str],
    group_size: int = 16,
) -> None:
    super().__init__()
    self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
    if is_checkpoint_nvfp4_serialized:
        logger.warning(
            "Detected ModelOpt NVFP4 checkpoint. Please note that"
            " the format is experimental and could change in future."
        )

    self.group_size = group_size
    self.kv_cache_quant_algo = kv_cache_quant_algo
    self.exclude_modules = exclude_modules

apply_vllm_mapper

apply_vllm_mapper(hf_to_vllm_mapper: WeightsMapper)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
    if self.exclude_modules is not None:
        self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

from_config classmethod

from_config(config: dict[str, Any]) -> ModelOptNvFp4Config
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
    # Handle both traditional ModelOpt format and compressed-tensors
    # style format
    if "quantization" in config:
        # Traditional ModelOpt format:
        # {"quantization": {"quant_algo": "..."}}
        quant_config = cls.get_from_keys(config, ["quantization"])
        if not isinstance(quant_config, dict):
            raise ValueError("Expected 'quantization' to be a dictionary in config")

        quant_method = quant_config.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Handle kv_cache_quant_algo with proper type validation
        kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
        if kv_cache_quant_algo_raw is None:
            # No KV cache quantization by default
            kv_cache_quant_algo = None
        elif isinstance(kv_cache_quant_algo_raw, str):
            kv_cache_quant_algo = kv_cache_quant_algo_raw
        else:
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_algo_raw)}"
            )

        # Handle group_size with proper type validation
        group_size_raw = quant_config.get("group_size")
        if group_size_raw is None:
            group_size = 16  # Default value
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        # "exclude_modules" is the key in the legacy hf_quant_config.json
        exclude_modules = quant_config.get("exclude_modules", [])
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )
    else:
        # Compressed-tensors style format:
        # {"quant_algo": "...", "quant_method": "modelopt"}
        quant_method = config.get("quant_algo", "")

        # Handle kv_cache_quant_algo with proper type validation
        kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
        if kv_cache_quant_algo_raw is None:
            # No KV cache quantization by default
            kv_cache_quant_algo = None
        elif isinstance(kv_cache_quant_algo_raw, str):
            kv_cache_quant_algo = kv_cache_quant_algo_raw
        else:
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_algo_raw)}"
            )

        # Handle group_size with proper type validation
        group_size_raw = config.get("group_size")
        if group_size_raw is None:
            group_size = 16  # Default value
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        # "ignore" is the key in config.json
        exclude_modules = config.get("ignore", [])
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration."
        )
    is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

    # For FP4, these fields are required
    if is_checkpoint_nvfp4_serialized and "quantization" in config:
        # Check if required fields are present in the quantization config
        quant_config = config["quantization"]
        required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
        missing_fields = [
            field for field in required_fields if field not in quant_config
        ]
        if missing_fields:
            raise ValueError(
                f"NVFP4 quantization requires the following fields in "
                f"hf_quant_config.json: {missing_fields}"
            )

    return cls(
        is_checkpoint_nvfp4_serialized,
        kv_cache_quant_algo,
        exclude_modules,
        group_size,
    )
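
For orientation, here is a minimal sketch of the two configuration layouts this classmethod accepts. The field values below are illustrative placeholders, not taken from a real checkpoint.

# Hypothetical configs showing the two layouts parsed by from_config.
from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

# Legacy hf_quant_config.json layout: a "quantization" wrapper with an
# "exclude_modules" key; group_size, kv_cache_quant_algo and exclude_modules
# are required for NVFP4 in this layout.
legacy_cfg = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    }
}

# Compressed-tensors style layout from config.json: flat keys and an
# "ignore" list instead of "exclude_modules".
flat_cfg = {
    "quant_method": "modelopt",
    "quant_algo": "NVFP4",
    "kv_cache_quant_algo": "FP8",
    "group_size": 16,
    "ignore": ["lm_head"],
}

for cfg in (legacy_cfg, flat_cfg):
    parsed = ModelOptNvFp4Config.from_config(cfg)
    print(parsed.group_size, parsed.kv_cache_quant_algo, parsed.exclude_modules)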

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return ["hf_quant_config.json"]

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_min_capability(cls) -> int:
    return 80

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "modelopt_fp4"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
    from vllm.attention.layer import Attention  # Avoid circular import

    skip_layer = self.is_layer_excluded(prefix)
    if isinstance(layer, LinearBase):
        if skip_layer:
            return UnquantizedLinearMethod()
        # Check if this is a vision model layer that should not be quantized
        if "vision_tower" in prefix or "vision_model" in prefix:
            return UnquantizedLinearMethod()
        return ModelOptNvFp4LinearMethod(self)
    elif isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)
    elif isinstance(layer, FusedMoE):
        if skip_layer:
            return None
        return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

is_layer_excluded

is_layer_excluded(prefix: str) -> bool

Check if a layer should be excluded from quantization. Handles both exact matching (for fused layers) and pattern matching.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def is_layer_excluded(self, prefix: str) -> bool:
    """
    Check if a layer should be excluded from quantization.
    Handles both exact matching (for fused layers) and pattern matching.
    """
    # First check exact matching with fused layer support
    if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
        return True

    # Check regex pattern matching for patterns not caught by exact match
    import regex as re

    for pattern in self.exclude_modules:
        # Skip patterns that would be caught by exact matching
        if "*" in pattern or "." in pattern:
            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
            if re.fullmatch(regex_str, prefix):
                return True
    return False
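
The wildcard handling above can be illustrated with a small standalone sketch. It uses the standard re module rather than the regex package, and the module prefixes and exclude patterns are hypothetical examples.

# Standalone illustration of the wildcard-to-regex translation used in
# is_layer_excluded; the prefixes and patterns here are made up.
import re

def matches_exclude_pattern(prefix: str, patterns: list[str]) -> bool:
    for pattern in patterns:
        # Only patterns containing "*" or "." are translated to regexes here.
        if "*" in pattern or "." in pattern:
            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
            if re.fullmatch(regex_str, prefix):
                return True
    return False

exclude_modules = ["model.layers.*.mlp.gate", "lm_head"]
print(matches_exclude_pattern("model.layers.0.mlp.gate", exclude_modules))             # True
print(matches_exclude_pattern("model.layers.0.self_attn.qkv_proj", exclude_modules))   # False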

override_quantization_method classmethod

override_quantization_method(
    hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]

Detect if this ModelOpt FP4 config should be used based on quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def override_quantization_method(
    cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
    """Detect if this ModelOpt FP4 config should be used based on
    quantization config."""
    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'
    quant_method = hf_quant_cfg.get("quant_method", "").lower()

    # Only proceed if the method is explicitly "modelopt"
    if quant_method != "modelopt":
        return None

    # Look for ModelOpt-specific config structure
    if "quantization" in hf_quant_cfg:
        quant_config = hf_quant_cfg["quantization"]
        if isinstance(quant_config, dict):
            quant_algo = quant_config.get("quant_algo", "")
            if "NVFP4" in quant_algo:
                return "modelopt_fp4"
    else:
        # Check for compressed-tensors style config with specific
        # quant_algo field
        quant_algo = hf_quant_cfg.get("quant_algo", "")
        if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
            return "modelopt_fp4"

    return None
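
For reference, a small sketch of which hf_quant_cfg shapes this detection accepts; the dictionaries below are made-up examples rather than dumps of real checkpoints.

# Illustrative hf_quant_cfg shapes for override_quantization_method.
from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

wrapped = {"quant_method": "modelopt", "quantization": {"quant_algo": "NVFP4"}}
flat = {"quant_method": "modelopt", "quant_algo": "NVFP4"}
other = {"quant_method": "awq", "quant_algo": "NVFP4"}

for cfg in (wrapped, flat, other):
    print(ModelOptNvFp4Config.override_quantization_method(cfg, None))
# -> modelopt_fp4, modelopt_fp4, None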

ModelOptNvFp4FusedMoE

Bases: FusedMoEMethodBase

MoE Method for FP4 Quantization. Args: quant_config: NVFP4 Quant Config

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
    Args:
        quant_config: NVFP4 Quant Config
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> None:
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
            detect_nvfp4_moe_support,
        )

        super().__init__(moe)
        self.quant_config = quant_config
        self.layer = layer
        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        self.flashinfer_moe_backend = None
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
                " for ModelOptNvFp4FusedMoE."
            )

    def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        if self.use_marlin or (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            # For now, fp4 moe only works with the flashinfer dispatcher.
            prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        assert self.moe_quant_config is not None
        experts = select_nvfp4_gemm_impl(
            self.moe,
            self.moe_quant_config,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
        # args_dequant,
        # args,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        from flashinfer import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import (
            _maybe_get_cached_w2_permute_indices,
            _maybe_get_cached_w3_w1_permute_indices,
        )

        """Prepare quantized weights for kernel (done offline with weights)."""
        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 16
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // 16
        )  # fp8 scaling factors

        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            # Calculate the permute indices for the following:
            # 1. Reorder rows of W1 and scales for fused gated activation
            # 2. Shuffle weights and scaling factors for transposed mma output
            # for both w3_w1 and w2 weights and scale factors
            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_fp4_shuffled.append(
                gemm1_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm1_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm1_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

            permute_indices = _maybe_get_cached_w2_permute_indices(
                self._cache_permute_indices,
                gemm2_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_fp4_shuffled.append(
                gemm2_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
                self._cache_permute_indices,
                gemm2_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm2_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // 16)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        layer.g2_alphas = Parameter(
            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            assert layer.w13_weight_scale.shape[2] % 16 == 0, (
                "Expected weight_scale.dim(1) to be divisible by 16"
            )
            assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, (
                "Weight Blockscale must be represented as FP8-E4M3"
            )
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            assert layer.w2_weight_scale.shape[2] % 16 == 0, (
                "Expected weight_scale.dim(1) to be divisible by 16"
            )
            assert layer.w2_weight_scale.dtype == torch.float8_e4m3fn, (
                "Weight Blockscale must be represented as FP8-E4M3"
            )
            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> Optional[FusedMoEQuantConfig]:
        if (
            self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
            )
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            assert self.fused_experts is None

            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
                flashinfer.fp4_quantize(
                    x,
                    a1_gscale,
                    is_sf_swizzled_layout=False,
                )
            )
            use_llama4_routing = (
                custom_routing_function is Llama4MoE.custom_routing_function
            )
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            if use_llama4_routing:
                routing_method_type = flashinfer.RoutingMethodType.Llama4
            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                routing_bias = routing_bias.to(torch.bfloat16)
            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                routing_logits=router_logits
                if use_llama4_routing
                else router_logits.to(torch.float32),
                routing_bias=routing_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn
                ).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group if num_expert_group is not None else 0,
                topk_group=topk_group if topk_group is not None else 0,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=None,
                tile_tokens_dim=_get_tile_tokens_dim(
                    x.shape[0], top_k, layer.local_num_experts
                ),
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
        # trtllm.
        #
        if self.use_marlin:
            assert self.fused_experts is None
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )

        elif self.fused_experts is not None:
            assert (
                self.allow_flashinfer
                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            )

            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight
            ), "Flashinfer CUTLASS Fused MoE not applicable!"

            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4,
            )

            assert self.moe_quant_config is not None

            return flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
                # TODO: derive from arguments
                m=x.shape[0],
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )

_cache_permute_indices instance-attribute

_cache_permute_indices: dict[Size, Tensor] = {}

allow_flashinfer instance-attribute

allow_flashinfer = allow_flashinfer

cutlass_nvfp4_supported instance-attribute

cutlass_nvfp4_supported = cutlass_supported

flashinfer_moe_backend instance-attribute

flashinfer_moe_backend = None

layer instance-attribute

layer = layer

quant_config instance-attribute

quant_config = quant_config

use_marlin instance-attribute

use_marlin = use_marlin

__init__

__init__(
    quant_config: ModelOptNvFp4Config,
    moe: FusedMoEConfig,
    layer: Module,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    quant_config: ModelOptNvFp4Config,
    moe: FusedMoEConfig,
    layer: torch.nn.Module,
) -> None:
    from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
        detect_nvfp4_moe_support,
    )

    super().__init__(moe)
    self.quant_config = quant_config
    self.layer = layer
    _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
    self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
    self.allow_flashinfer = _nvfp4.allow_flashinfer
    self.use_marlin = _nvfp4.use_marlin
    self.flashinfer_moe_backend = None
    self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    if self.allow_flashinfer:
        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE."
        )

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
        )
    assert activation == "silu", "Only SiLU activation is supported."

    if (
        self.allow_flashinfer
        and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
    ):
        import flashinfer

        from vllm.model_executor.models.llama4 import Llama4MoE

        assert self.fused_experts is None

        a1_gscale = layer.w13_input_scale_quant
        (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
            flashinfer.fp4_quantize(
                x,
                a1_gscale,
                is_sf_swizzled_layout=False,
            )
        )
        use_llama4_routing = (
            custom_routing_function is Llama4MoE.custom_routing_function
        )
        routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
        if use_llama4_routing:
            routing_method_type = flashinfer.RoutingMethodType.Llama4
        routing_bias = e_score_correction_bias
        if routing_bias is not None:
            routing_bias = routing_bias.to(torch.bfloat16)
        out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
            routing_logits=router_logits
            if use_llama4_routing
            else router_logits.to(torch.float32),
            routing_bias=routing_bias,
            hidden_states=hidden_states_fp4,
            hidden_states_scale=hidden_states_scale_linear_fp4.view(
                torch.float8_e4m3fn
            ).flatten(),
            gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
            gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn
            ),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
            gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn
            ),
            gemm2_bias=None,
            output1_scale_scalar=layer.g1_scale_c.data,
            output1_scale_gate_scalar=layer.g1_alphas.data,
            output2_scale_scalar=layer.g2_alphas.data,
            num_experts=global_num_experts,
            top_k=top_k,
            n_group=num_expert_group if num_expert_group is not None else 0,
            topk_group=topk_group if topk_group is not None else 0,
            intermediate_size=layer.intermediate_size_per_partition,
            local_expert_offset=layer.ep_rank * layer.local_num_experts,
            local_num_experts=layer.local_num_experts,
            routed_scaling_factor=None,
            tile_tokens_dim=_get_tile_tokens_dim(
                x.shape[0], top_k, layer.local_num_experts
            ),
            routing_method_type=routing_method_type,
            do_finalize=True,
        )[0]
        return out

    topk_weights, topk_ids, _ = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
    )

    #
    # Note: the order here is important. self.fused_experts can override
    # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
    # trtllm.
    #
    if self.use_marlin:
        assert self.fused_experts is None
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            None,
            None,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            global_scale1=layer.w13_weight_scale_2,
            global_scale2=layer.w2_weight_scale_2,
            quant_type_id=scalar_types.float4_e2m1f.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            workspace=layer.workspace,
        )

    elif self.fused_experts is not None:
        assert (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        )

        assert is_valid_flashinfer_cutlass_fused_moe(
            x, layer.w13_weight, layer.w2_weight
        ), "Flashinfer CUTLASS Fused MoE not applicable!"

        return self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,  # TODO(shuw): fix later, now output is high prec
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    elif (
        self.allow_flashinfer
        and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
    ):
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
            flashinfer_cutlass_moe_fp4,
        )

        assert self.moe_quant_config is not None

        return flashinfer_cutlass_moe_fp4(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=self.moe_quant_config,
            inplace=False,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    else:
        # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
        # only (no EP).
        from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

        assert self.moe_quant_config is not None
        return cutlass_moe_fp4(
            a=x,
            w1_fp4=layer.w13_weight,
            w2_fp4=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=self.moe_quant_config,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            # TODO: derive from arguments
            m=x.shape[0],
            n=layer.w2_weight.shape[2] * 2,
            k=x.shape[1],
            e=layer.w13_weight.shape[0],
        )

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError(
            "NVFP4 quantization was selected, "
            " dynamic quantization is not supported."
        )

    layer.num_experts = num_experts
    layer.params_dtype = params_dtype
    layer.quant_config = self.quant_config
    weight_dtype = torch.uint8
    weight_scale_dtype = torch.float8_e4m3fn
    weight_loader = extra_weight_attrs.get("weight_loader")
    # GEMM 1
    w13_weight = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            # 2 fp4 items are packed in the input dimension
            hidden_size // 2,
            dtype=weight_dtype,
        ),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w13_weight", w13_weight)

    # GEMM 2
    w2_weight = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            hidden_size,
            # 2 fp4 items are packed in the input dimension
            intermediate_size_per_partition // 2,
            dtype=weight_dtype,
        ),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w2_weight", w2_weight)

    w13_weight_scale = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            # 2 fp4 items are packed in the input dimension
            hidden_size // self.quant_config.group_size,
            dtype=weight_scale_dtype,
        ),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w13_weight_scale", w13_weight_scale)

    w2_weight_scale = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            hidden_size,
            # 2 fp4 items are packed in the input dimension
            intermediate_size_per_partition // self.quant_config.group_size,
            dtype=weight_scale_dtype,
        ),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w2_weight_scale", w2_weight_scale)

    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
    )

    w13_weight_scale_2 = PerTensorScaleParameter(
        data=torch.empty(num_experts, 2, dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

    w2_weight_scale_2 = PerTensorScaleParameter(
        data=torch.empty(num_experts, dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
    )

    w13_input_scale = PerTensorScaleParameter(
        data=torch.empty(num_experts, 2, dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("w13_input_scale", w13_input_scale)

    w2_input_scale = PerTensorScaleParameter(
        data=torch.empty(num_experts, dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("w2_input_scale", w2_input_scale)

get_fused_moe_quant_config

get_fused_moe_quant_config(
    layer: Module,
) -> Optional[FusedMoEQuantConfig]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
    if (
        self.use_marlin
        or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
    ):
        return None

    return nvfp4_moe_quant_config(
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        g1_alphas=layer.g1_alphas,
        g2_alphas=layer.g2_alphas,
        a1_gscale=layer.w13_input_scale_quant,
        a2_gscale=layer.w2_input_scale_quant,
    )

maybe_make_prepare_finalize

maybe_make_prepare_finalize() -> Optional[
    FusedMoEPrepareAndFinalize
]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    if self.use_marlin or (
        self.allow_flashinfer
        and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
    ):
        return None
    elif (
        self.allow_flashinfer
        and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
    ):
        # For now, fp4 moe only works with the flashinfer dispatcher.
        prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
            self.moe
        )
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize
    else:
        return super().maybe_make_prepare_finalize()

prepare_static_weights_for_trtllm_fp4_moe

prepare_static_weights_for_trtllm_fp4_moe(
    gemm1_weights,
    gemm2_weights,
    gemm1_scales_linear_fp4_bytes,
    gemm2_scales_linear_fp4_bytes,
    hidden_size,
    intermediate_size,
    num_experts,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def prepare_static_weights_for_trtllm_fp4_moe(
    self,
    # args_dequant,
    # args,
    gemm1_weights,
    gemm2_weights,
    gemm1_scales_linear_fp4_bytes,
    gemm2_scales_linear_fp4_bytes,
    hidden_size,
    intermediate_size,
    num_experts,
):
    """Prepare quantized weights for kernel (done offline with weights)."""
    from flashinfer import nvfp4_block_scale_interleave
    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w2_permute_indices,
        _maybe_get_cached_w3_w1_permute_indices,
    )

    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

    # Convert quantized weights to proper formats
    gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 2
    )  # packed fp4
    gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn
    ).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 16
    )  # fp8 scaling factors

    gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 2
    )  # packed fp4
    gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn
    ).reshape(
        num_experts, hidden_size, intermediate_size // 16
    )  # fp8 scaling factors

    gemm1_weights_fp4_shuffled = []
    gemm1_scales_fp4_shuffled = []
    gemm2_weights_fp4_shuffled = []
    gemm2_scales_fp4_shuffled = []
    for i in range(num_experts):
        # Calculate the permute indices for the following:
        # 1. Reorder rows of W1 and scales for fused gated activation
        # 2. Shuffle weights and scaling factors for transposed mma output
        # for both w3_w1 and w2 weights and scale factors
        permute_indices = _maybe_get_cached_w3_w1_permute_indices(
            self._cache_permute_indices,
            gemm1_weights_fp4[i].view(torch.uint8),
            epilogue_tile_m,
        )
        gemm1_weights_fp4_shuffled.append(
            gemm1_weights_fp4[i]
            .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
            .contiguous()
        )

        permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
            self._cache_permute_indices,
            gemm1_scales_linear_fp4[i].view(torch.uint8),
            epilogue_tile_m,
            num_elts_per_sf=16,
        )
        gemm1_scales_fp4_shuffled.append(
            nvfp4_block_scale_interleave(
                gemm1_scales_linear_fp4[i]
                .view(torch.uint8)[
                    permute_sf_indices.to(gemm1_scales_linear_fp4.device)
                ]
                .contiguous()
            )
        )

        permute_indices = _maybe_get_cached_w2_permute_indices(
            self._cache_permute_indices,
            gemm2_weights_fp4[i].view(torch.uint8),
            epilogue_tile_m,
        )
        gemm2_weights_fp4_shuffled.append(
            gemm2_weights_fp4[i]
            .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
            .contiguous()
        )

        permute_sf_indices = _maybe_get_cached_w2_permute_indices(
            self._cache_permute_indices,
            gemm2_scales_linear_fp4[i].view(torch.uint8),
            epilogue_tile_m,
            num_elts_per_sf=16,
        )
        gemm2_scales_fp4_shuffled.append(
            nvfp4_block_scale_interleave(
                gemm2_scales_linear_fp4[i]
                .view(torch.uint8)[
                    permute_sf_indices.to(gemm2_scales_linear_fp4.device)
                ]
                .contiguous()
            )
        )

    # Stack weights for all experts
    gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
    gemm1_scales_fp4_shuffled = (
        torch.stack(gemm1_scales_fp4_shuffled)
        .view(torch.float8_e4m3fn)
        .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
    )

    gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
    gemm2_scales_fp4_shuffled = (
        torch.stack(gemm2_scales_fp4_shuffled)
        .view(torch.float8_e4m3fn)
        .reshape(num_experts, hidden_size, intermediate_size // 16)
    )
    return (
        gemm1_weights_fp4_shuffled,
        gemm1_scales_fp4_shuffled,
        gemm2_weights_fp4_shuffled,
        gemm2_scales_fp4_shuffled,
    )
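
The per-expert loop applies a cached row permutation to each packed weight and block-scale tensor and then stacks the results. The toy sketch below (plain PyTorch, no flashinfer; a random permutation stands in for the kernel-specific indices) illustrates only that reorder-then-stack pattern, not the actual TRT-LLM layout:

import torch

num_experts, rows, cols = 4, 8, 16
weights = torch.randint(0, 256, (num_experts, rows, cols), dtype=torch.uint8)

# Stand-in for the cached kernel-specific permute indices used above.
permute_indices = torch.randperm(rows)

# Reorder each expert's rows, then stack everything back into one tensor.
shuffled = torch.stack(
    [weights[e][permute_indices].contiguous() for e in range(num_experts)]
)
assert shuffled.shape == weights.shape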

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # GEMM 1 processing
    gemm1_weight = layer.w13_weight.data
    gemm1_weight_scale = layer.w13_weight_scale.data

    if self.allow_flashinfer:
        gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
            gemm1_weight, gemm1_weight_scale, dim=-2
        )

    layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
    layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

    # Common processing for w13_weight_scale_2
    if not torch.allclose(
        layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
    ):
        logger.warning_once(
            "w1_weight_scale_2 must match w3_weight_scale_2. "
            "Accuracy may be affected."
        )

    w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
    layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

    # Common processing for input scales and alphas
    w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
    layer.g1_alphas = Parameter(
        (w13_input_scale * w13_weight_scale_2).to(torch.float32),
        requires_grad=False,
    )

    # This is for quantization, so we need to invert it.
    layer.w13_input_scale_quant = Parameter(
        (1 / w13_input_scale).to(torch.float32), requires_grad=False
    )

    # GEMM 2 processing
    layer.g2_alphas = Parameter(
        (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
        requires_grad=False,
    )

    # This is for quantization, so we need to invert it.
    layer.w2_input_scale_quant = Parameter(
        (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
    )

    # TensorRT-LLM specific processing
    if (
        self.allow_flashinfer
        and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
    ):
        # Prepare static weights for TRT-LLM kernel
        # alternate: prepare_static_weight_layouts_for_trtllm_moe
        (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        ) = self.prepare_static_weights_for_trtllm_fp4_moe(
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            layer.w2_weight.size(-2),  # hidden_size
            layer.w13_weight.size(-2) // 2,  # intermediate_size
            layer.w13_weight.size(0),  # num_experts
        )
        logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

        layer.gemm1_weights_fp4_shuffled = Parameter(
            gemm1_weights_fp4_shuffled, requires_grad=False
        )
        layer.gemm2_weights_fp4_shuffled = Parameter(
            gemm2_weights_fp4_shuffled, requires_grad=False
        )
        layer.gemm1_scales_fp4_shuffled = Parameter(
            gemm1_scales_fp4_shuffled, requires_grad=False
        )
        layer.gemm2_scales_fp4_shuffled = Parameter(
            gemm2_scales_fp4_shuffled, requires_grad=False
        )

        # Additional parameter needed for TRT-LLM
        layer.g1_scale_c = Parameter(
            (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
            requires_grad=False,
        )

        # Clean up weights that won't be used by TRT-LLM
        del layer.w2_weight
        del layer.w2_weight_scale
        del layer.w13_weight
        del layer.w13_weight_scale
    elif self.use_marlin:
        # Marlin processing
        prepare_moe_fp4_layer_for_marlin(layer)
        del layer.g1_alphas
        del layer.g2_alphas
        del layer.w13_input_scale_quant
        del layer.w2_input_scale_quant
    else:
        # Non-TRT-LLM processing (Cutlass or non-flashinfer)
        assert layer.w13_weight_scale.shape[2] % 16 == 0, (
            "Expected weight_scale.dim(1) to be divisible by 16"
        )
        assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Blockscale must be represented as FP8-E4M3"
        )
        w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
        layer.w13_weight_scale = Parameter(
            w13_blockscale_swizzled, requires_grad=False
        )

        assert layer.w2_weight_scale.shape[2] % 16 == 0, (
            "Expected weight_scale.dim(1) to be divisible by 16"
        )
        assert layer.w2_weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Blockscale must be represented as FP8-E4M3"
        )
        w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
        layer.w2_weight_scale = Parameter(
            w2_blockscale_swizzled, requires_grad=False
        )
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
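
The scale bookkeeping above is a handful of element-wise products: the two shard input scales collapse via a max, each alpha is the product of an activation scale and a global weight scale, and the reciprocal of the activation scale is what quantizes inputs at runtime. A small sketch with made-up values:

import torch

# Hypothetical per-expert scales, for illustration only.
num_experts = 4
w13_input_scale = torch.rand(num_experts, 2) + 0.5    # one scale per shard (w1, w3)
w13_weight_scale_2 = torch.rand(num_experts) + 0.5    # per-expert global weight scale

# Collapse the two shard input scales into one per expert, as above.
input_scale = w13_input_scale.max(dim=1).values.to(torch.float32)

# alpha rescales the FP4 GEMM output; its reciprocal quantizes activations.
g1_alphas = input_scale * w13_weight_scale_2
w13_input_scale_quant = 1.0 / input_scale

assert torch.allclose(g1_alphas * w13_input_scale_quant, w13_weight_scale_2)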

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: Module,
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/modelopt.py
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    assert self.moe_quant_config is not None
    experts = select_nvfp4_gemm_impl(
        self.moe,
        self.moe_quant_config,
        allow_flashinfer=self.allow_flashinfer,
    )
    logger.debug_once("Using %s", experts.__class__.__name__)
    return experts

uses_weight_scale_2_pattern

uses_weight_scale_2_pattern() -> bool

FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def uses_weight_scale_2_pattern(self) -> bool:
    """
    FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
    """
    return True

ModelOptNvFp4LinearMethod

Bases: LinearMethodBase

Linear method for Model Optimizer NVFP4. Supports loading NVFP4 checkpoints with the following structure:

input_scale: torch.float32, scalar
weight: NVFP4 (represented as bytes), shape: [1, X, Y/2]
weight_scale: FP8-E4M3, shape: [X, Y], i.e. the per-block scale
weight_scale_2: torch.float32, scalar

Args: quant_config: The ModelOpt quantization config.
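
For a concrete feel of this layout, the sketch below allocates the checkpoint tensors for a hypothetical 4096x4096 partition, with the NVFP4 group size assumed to be 16; it is an illustration of the shapes only, not part of the implementation:

import torch

# Hypothetical partition sizes, for illustration only.
output_size_per_partition = 4096   # X
input_size_per_partition = 4096    # Y
group_size = 16                    # NVFP4 block size assumed here

# Packed FP4 weight: two 4-bit values per uint8 along the input dimension.
weight = torch.empty(
    output_size_per_partition, input_size_per_partition // 2, dtype=torch.uint8
)
# One FP8-E4M3 block scale per `group_size` input elements.
weight_scale = torch.empty(
    output_size_per_partition,
    input_size_per_partition // group_size,
    dtype=torch.float8_e4m3fn,
)
# Global scales are plain float32 scalars.
weight_scale_2 = torch.tensor(1.0, dtype=torch.float32)
input_scale = torch.tensor(1.0, dtype=torch.float32)
# Conceptually: w[i, j] ~ fp4_value * weight_scale[i, j // group_size] * weight_scale_2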

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, scalar ,
    weight: NVFP4(represented as byte) Shape: [1, X, y/2]
    weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
    weight_scale_2: torch.float32, scalar,
    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
            self.backend = "flashinfer-trtllm"
        elif has_flashinfer():
            self.backend = "flashinfer-cutlass"
        elif cutlass_fp4_supported():
            self.backend = "cutlass"
        elif is_fp4_marlin_supported():
            self.backend = "marlin"
        else:
            raise ValueError(
                "Current platform does not support NVFP4"
                " quantization. Please use Blackwell and"
                " above."
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )
        # For NVFP4-serialized checkpoints, the per-block weight scale below is
        # stored as FP8-E4M3.
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )
        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend == "flashinfer-trtllm":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
        elif self.backend == "flashinfer-cutlass":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
        else:
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)

backend instance-attribute

backend = 'flashinfer-trtllm'

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: ModelOptNvFp4Config) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
    self.quant_config = quant_config

    if envs.VLLM_USE_TRTLLM_FP4_GEMM:
        assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
        self.backend = "flashinfer-trtllm"
    elif has_flashinfer():
        self.backend = "flashinfer-cutlass"
    elif cutlass_fp4_supported():
        self.backend = "cutlass"
    elif is_fp4_marlin_supported():
        self.backend = "marlin"
    else:
        raise ValueError(
            "Current platform does not support NVFP4"
            " quantization. Please use Blackwell and"
            " above."
        )

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if self.backend == "marlin":
        return apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )

    output_dtype = x.dtype
    output_shape = [x.shape[0], layer.weight.shape[0]]

    # quantize BF16 or FP16 to (FP4 and interleaved block scale)
    x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

    # validate dtypes of quantized input, input block scale,
    # weight and weight_blockscale
    assert x_fp4.dtype == torch.uint8
    assert layer.weight.dtype == torch.uint8
    assert x_blockscale.dtype == torch.float8_e4m3fn
    assert layer.weight_scale.dtype == torch.float8_e4m3fn
    assert layer.alpha.dtype == torch.float32

    mm_args = (
        x_fp4,
        layer.weight,
        x_blockscale,
        layer.weight_scale,
        layer.alpha,
        output_dtype,
    )
    if self.backend == "flashinfer-trtllm":
        out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
    elif self.backend == "flashinfer-cutlass":
        out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
    else:
        out = cutlass_scaled_fp4_mm(*mm_args)

    if bias is not None:
        out = out + bias
    return out.view(*output_shape)
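
At the per-tensor level, apply() relies on alpha = input_scale * weight_scale_2 to undo the scaling applied when activations and weights were quantized. The toy emulation below ignores FP4 rounding and block scales entirely and only demonstrates that global-scale identity:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
w = torch.randn(16, 8)

# Hypothetical global scales (6.0 is the largest magnitude FP4/E2M1 can encode).
input_scale = x.abs().max() / 6.0
weight_scale_2 = w.abs().max() / 6.0

# "Quantize" by dividing out the global scales (FP4 rounding and block scales omitted).
x_q = x / input_scale
w_q = w / weight_scale_2

alpha = input_scale * weight_scale_2
out = (x_q @ w_q.t()) * alpha  # recovers x @ w.t() up to float rounding

assert torch.allclose(out, x @ w.t(), atol=1e-5)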

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    del input_size, output_size
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError(
            "NVFP4 quantization was selected, "
            " dynamic quantization is not supported."
        )
    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition

    if input_size_per_partition % 16 != 0:
        raise ValueError(
            "Unsupported model when in features size is not multiple of 16"
        )
    # For NVFP4-serialized checkpoints, the per-block weight scale below is
    # stored as FP8-E4M3.
    weight_dtype = (
        torch.float8_e4m3fn
        if self.quant_config.is_checkpoint_nvfp4_serialized
        else params_dtype
    )
    # Weight
    weight = ModelWeightParameter(
        data=torch.empty(
            # 2 fp4 items are packed in the input dimension
            layer.output_size_per_partition,
            layer.input_size_per_partition // 2,
            dtype=torch.uint8,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight", weight)

    # Input Weight Scale
    input_scale = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("input_scale", input_scale)

    # Global Weight Scale
    weight_scale_2 = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight_scale_2", weight_scale_2)

    # Per Block Weight Scale
    weight_scale = ModelWeightParameter(
        data=torch.empty(
            output_size_per_partition,
            input_size_per_partition // self.quant_config.group_size,
            dtype=weight_dtype,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )

    layer.register_parameter("weight_scale", weight_scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: Module) -> None:
    # global scales:
    input_scale_2 = layer.input_scale.max().to(torch.float32)
    layer.input_scale = Parameter(input_scale_2, requires_grad=False)

    weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
    layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

    layer.alpha = Parameter(
        layer.input_scale * layer.weight_scale_2, requires_grad=False
    )

    # Calculate `1 / input_scale` so that we don't need to do so at runtime
    layer.input_scale_inv = Parameter(
        (1 / layer.input_scale).to(torch.float32), requires_grad=False
    )

    # Swizzle the weight blockscale.
    # contracting dimension is input dimension
    # block_size = 16;
    assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
        "Weight Block scale must be represented as FP8-E4M3"
    )

    if self.backend == "marlin":
        prepare_fp4_layer_for_marlin(layer)
        del layer.alpha
        del layer.input_scale
    elif self.backend == "flashinfer-trtllm":
        # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
        # FlashInfer provides nvfp4_quantize to quantize + shuffle the
        # layout but we use our own quantization so we have to call
        # shuffles ourselves.
        from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

        weight = layer.weight.data
        weight_scale = layer.weight_scale.data

        epilogue_tile_m = 128
        weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
        weight_scale = (
            shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
            .reshape(weight_scale.shape)
            .view(torch.float8_e4m3fn)
        )

        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        layer.weight = Parameter(weight, requires_grad=False)
    else:
        swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
        layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

_get_tile_tokens_dim

_get_tile_tokens_dim(
    num_tokens: int, top_k: int, num_experts: int
) -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Guess tokens per expert assuming perfect expert distribution first.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # And pad the number to the next power of 2.
    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
    return tile_tokens_dim
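
Worked examples of the tile sizing, using a stand-in for vLLM's next_power_of_2 helper (an assumption; the real utility lives elsewhere in vLLM):

# Stand-in for vLLM's next_power_of_2 utility, for illustration only.
def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)


print(tile_tokens_dim(1, 8, 128))   # 8  (tiny batch is clamped up to 8)
print(tile_tokens_dim(256, 8, 64))  # 32 (256 * 8 // 64 = 32, already a power of 2)
print(tile_tokens_dim(4096, 8, 64)) # 64 (512 is clamped down to 64)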