Skip to content

vllm.model_executor.model_loader.default_loader

logger module-attribute

logger = init_logger(__name__)

DefaultModelLoader

Bases: BaseModelLoader

Model loader that can load different file types from disk.

Source code in vllm/model_executor/model_loader/default_loader.py
class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk.

    The loader resolves a model ID or local directory to a list of weight
    files (downloading them when the path is not a local directory), picks a
    weights iterator matching the file type and ``load_config.load_format``,
    and feeds the resulting ``(name, tensor)`` pairs to ``model.load_weights``.
    """

    # Default number of threads used when multithreaded weight loading is
    # enabled and "num_threads" is not given in the extra loader config.
    DEFAULT_NUM_THREADS = 8

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: Optional[list[str]] = None
        """If defined, weights will load exclusively using these patterns."""

    # `time.perf_counter()` readings taken when the first weights iterator of
    # a load is created and when `load_weights` finishes. 0.0 means "not set".
    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        """Create the loader and validate the extra loader config.

        Args:
            load_config: Load configuration; its
                ``model_loader_extra_config`` may only contain the keys
                ``"enable_multithread_load"`` and ``"num_threads"``.

        Raises:
            ValueError: If ``model_loader_extra_config`` contains any other
                key.
        """
        super().__init__(load_config)

        extra_config = load_config.model_loader_extra_config
        allowed_keys = {"enable_multithread_load", "num_threads"}
        unexpected_keys = set(extra_config.keys()) - allowed_keys

        if unexpected_keys:
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{unexpected_keys}"
            )

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
        allow_patterns_overrides: Optional[list[str]],
    ) -> tuple[str, list[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Args:
            model_name_or_path: Model ID or path to a local directory.
            revision: Optional model revision passed to the downloaders.
            fall_back_to_pt: Whether ``*.pt`` files may be used when none of
                the format's own patterns match.
            allow_patterns_overrides: If not ``None``, replaces the glob
                patterns derived from the load format entirely.

        Returns:
            A tuple ``(hf_folder, hf_weights_files, use_safetensors)``: the
            directory that was searched, the weight files selected from it,
            and whether those files are to be read as safetensors.

        Raises:
            ValueError: If ``load_config.load_format`` is not one of the
                formats handled here.
            RuntimeError: If no weight file is left after filtering.
        """
        model_name_or_path = (
            maybe_download_from_modelscope(model_name_or_path, revision)
            or model_name_or_path
        )

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        if load_format == "auto":
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == "safetensors" or load_format == "fastsafetensors":
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == "mistral":
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == "pt":
            allow_patterns = ["*.pt"]
        elif load_format == "npcache":
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        # Some quantized models use .pt files for storing the weights.
        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in priority order; the first one that matches
        # any file decides which files (and which file type) are loaded.
        hf_weights_files: list[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                # Detect safetensors by suffix rather than by comparing with
                # the literal "*.safetensors", so that override patterns such
                # as "model-*.safetensors" are not mistaken for .pt/.bin
                # checkpoints.
                if pattern.endswith(".safetensors"):
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file
            )
        else:
            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, source: "Source"
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        Args:
            source: The weight source to resolve and iterate over.

        Returns:
            A generator of ``(name, tensor)`` pairs where every name has
            ``source.prefix`` prepended.
        """
        extra_config = self.load_config.model_loader_extra_config
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path,
            source.revision,
            source.fall_back_to_pt,
            source.allow_patterns_overrides,
        )
        if self.load_config.load_format == "npcache":
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        elif use_safetensors:
            if self.load_config.load_format == "fastsafetensors":
                weights_iterator = fastsafetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
            else:
                if extra_config.get("enable_multithread_load"):
                    weights_iterator = multi_thread_safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        max_workers=extra_config.get(
                            "num_threads", self.DEFAULT_NUM_THREADS
                        ),
                    )
                else:
                    weights_iterator = safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        self.load_config.safetensors_load_strategy,
                    )
        else:
            if extra_config.get("enable_multithread_load"):
                weights_iterator = multi_thread_pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                    max_workers=extra_config.get(
                        "num_threads", self.DEFAULT_NUM_THREADS
                    ),
                )
            else:
                weights_iterator = pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                )

        if current_platform.is_tpu():
            from vllm.platforms.tpu import USE_TPU_INFERENCE

            if not USE_TPU_INFERENCE:
                # In PyTorch XLA, we should call `torch_xla.sync`
                # frequently so that not too many ops are accumulated
                # in the XLA program.
                import torch_xla

                def _xla_weights_iterator(iterator: Generator):
                    for weights in iterator:
                        yield weights
                        torch_xla.sync(wait=False)

                weights_iterator = _xla_weights_iterator(weights_iterator)

        # Only the first iterator created for a load starts the timer, so
        # secondary sources do not move the start timestamp.
        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
        # Apply the prefix.
        return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield the weights of the primary source, then of any secondary ones.

        The primary source is built from ``model_config.model`` and
        ``model_config.revision``. The optional model attributes
        ``fall_back_to_pt_during_load``, ``allow_patterns_overrides`` and
        ``secondary_weights`` customise what is loaded.
        """
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        """Resolve (and, if not local, download) the model's weight files."""
        self._prepare_weights(
            model_config.model,
            model_config.revision,
            fall_back_to_pt=True,
            allow_patterns_overrides=None,
        )

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load the checkpoint weights into ``model`` and log the time taken.

        Args:
            model: Module whose ``load_weights`` method consumes the
                ``(name, tensor)`` pairs.
            model_config: Provides the model ID, revision and quantization
                method.

        Raises:
            ValueError: If the model is not quantized, ``model.load_weights``
                returned the set of loaded names, and at least one named
                parameter of ``model`` is missing from that set.
        """
        if model_config.quantization == "torchao" and torchao_version_at_least(
            "0.14.0"
        ):
            self.load_config.safetensors_load_strategy = "torchao"
        weights_to_load = {name for name, _ in model.named_parameters()}

        # Re-arm the start timestamp. `_get_weights_iterator` only sets it
        # while it is 0.0, so without this reset a second call on the same
        # loader would report the time elapsed since the *first* load.
        self.counter_before_loading_weights = 0.0

        # if we don't have `model.weight_metadata_and_attr_saved` defined and
        # set to True, it means that this is either offline quantization case
        # or the first run of online quantization
        # see online_quantization.py for detailed notes
        offline_quantization_or_first_run_of_online_quantization = not getattr(
            model, "weight_metadata_and_attr_saved", False
        )

        if (
            model_config.quantization is None
            or offline_quantization_or_first_run_of_online_quantization
        ):
            # Either the model is not quantized, or:
            # case 1: offline quantized checkpoint
            # case 2: Step I1 first run of weight loading with
            # online quantization
            # see online_quantization.py for detailed notes
            loaded_weights = model.load_weights(
                self.get_all_weights(model_config, model)
            )
        else:
            # to avoid circular dependency
            from vllm.model_executor.model_loader.online_quantization import (
                load_weights_and_online_quantize,
            )

            # subsequent runs of weight loading with online
            # quantization
            loaded_weights = load_weights_and_online_quantize(self, model, model_config)

        self.counter_after_loading_weights = time.perf_counter()
        logger.info(
            "Loading weights took %.2f seconds",
            self.counter_after_loading_weights - self.counter_before_loading_weights,
        )
        # We only enable strict check for non-quantized models
        # that have loaded weights tracking currently.
        if model_config.quantization is None and loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError(
                    "Following weights were not initialized from "
                    f"checkpoint: {weights_not_loaded}"
                )

DEFAULT_NUM_THREADS class-attribute instance-attribute

DEFAULT_NUM_THREADS = 8

counter_after_loading_weights class-attribute instance-attribute

counter_after_loading_weights: float = 0.0

counter_before_loading_weights class-attribute instance-attribute

counter_before_loading_weights: float = 0.0

Source dataclass

A source for weights.

Source code in vllm/model_executor/model_loader/default_loader.py
@dataclasses.dataclass
class Source:
    """Description of one place model weights are read from."""

    model_or_path: str
    """Model ID, or path to the model."""

    revision: Optional[str]
    """Model revision to use; may be ``None``."""

    prefix: str = dataclasses.field(default="")
    """String put in front of every weight name from this source."""

    fall_back_to_pt: bool = dataclasses.field(default=True)
    """Whether ``.pt`` weight files are acceptable."""

    allow_patterns_overrides: Optional[list[str]] = dataclasses.field(default=None)
    """When set, only these patterns are used to select weight files."""

allow_patterns_overrides class-attribute instance-attribute

allow_patterns_overrides: Optional[list[str]] = None

If defined, weights will load exclusively using these patterns.

fall_back_to_pt class-attribute instance-attribute

fall_back_to_pt: bool = True

Whether .pt weights can be used.

model_or_path instance-attribute

model_or_path: str

The model ID or path.

prefix class-attribute instance-attribute

prefix: str = ''

A prefix to prepend to all weights.

revision instance-attribute

revision: Optional[str]

The optional model revision.

__init__

__init__(
    model_or_path: str,
    revision: Optional[str],
    prefix: str = "",
    fall_back_to_pt: bool = True,
    allow_patterns_overrides: Optional[list[str]] = None,
) -> None

__init__

__init__(load_config: LoadConfig)
Source code in vllm/model_executor/model_loader/default_loader.py
def __init__(self, load_config: LoadConfig):
    """Set up the loader and reject unsupported extra loader config keys.

    Only ``"enable_multithread_load"`` and ``"num_threads"`` are accepted in
    ``load_config.model_loader_extra_config``; any other key raises
    ``ValueError``.
    """
    super().__init__(load_config)

    loader_options = load_config.model_loader_extra_config
    unexpected_keys = set(loader_options.keys()).difference(
        {"enable_multithread_load", "num_threads"}
    )
    if not unexpected_keys:
        return

    raise ValueError(
        f"Unexpected extra config keys for load format "
        f"{load_config.load_format}: "
        f"{unexpected_keys}"
    )

_get_weights_iterator

_get_weights_iterator(
    source: Source,
) -> Generator[tuple[str, Tensor], None, None]

Get an iterator for the model weights based on the load format.

Source code in vllm/model_executor/model_loader/default_loader.py
def _get_weights_iterator(
    self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Build the ``(name, tensor)`` iterator for one weight source.

    The source is first resolved to concrete files via ``_prepare_weights``;
    the iterator implementation is then chosen from the load format, the
    file type, and the ``enable_multithread_load`` extra option. Every
    yielded name carries ``source.prefix`` in front.
    """
    load_cfg = self.load_config
    loader_options = load_cfg.model_loader_extra_config
    hf_folder, weight_files, is_safetensors = self._prepare_weights(
        source.model_or_path,
        source.revision,
        source.fall_back_to_pt,
        source.allow_patterns_overrides,
    )

    if load_cfg.load_format == "npcache":
        # np_cache currently handles *.bin checkpoints only.
        assert is_safetensors is False
        tensors = np_cache_weights_iterator(
            source.model_or_path,
            load_cfg.download_dir,
            hf_folder,
            weight_files,
            load_cfg.use_tqdm_on_load,
        )
    elif is_safetensors and load_cfg.load_format == "fastsafetensors":
        tensors = fastsafetensors_weights_iterator(
            weight_files,
            load_cfg.use_tqdm_on_load,
        )
    elif is_safetensors and loader_options.get("enable_multithread_load"):
        tensors = multi_thread_safetensors_weights_iterator(
            weight_files,
            load_cfg.use_tqdm_on_load,
            max_workers=loader_options.get("num_threads", self.DEFAULT_NUM_THREADS),
        )
    elif is_safetensors:
        tensors = safetensors_weights_iterator(
            weight_files,
            load_cfg.use_tqdm_on_load,
            load_cfg.safetensors_load_strategy,
        )
    elif loader_options.get("enable_multithread_load"):
        tensors = multi_thread_pt_weights_iterator(
            weight_files,
            load_cfg.use_tqdm_on_load,
            load_cfg.pt_load_map_location,
            max_workers=loader_options.get("num_threads", self.DEFAULT_NUM_THREADS),
        )
    else:
        tensors = pt_weights_iterator(
            weight_files,
            load_cfg.use_tqdm_on_load,
            load_cfg.pt_load_map_location,
        )

    if current_platform.is_tpu():
        from vllm.platforms.tpu import USE_TPU_INFERENCE

        if not USE_TPU_INFERENCE:
            # With PyTorch XLA, `torch_xla.sync` has to be called often so
            # that the XLA program does not accumulate too many ops.
            import torch_xla

            def _sync_after_each(inner: Generator):
                for item in inner:
                    yield item
                    torch_xla.sync(wait=False)

            tensors = _sync_after_each(tensors)

    # Record the start time only if it has not been recorded yet.
    if self.counter_before_loading_weights == 0.0:
        self.counter_before_loading_weights = time.perf_counter()

    # Prepend the source prefix to each weight name as it is consumed.
    return ((source.prefix + key, value) for key, value in tensors)

_prepare_weights

_prepare_weights(
    model_name_or_path: str,
    revision: Optional[str],
    fall_back_to_pt: bool,
    allow_patterns_overrides: Optional[list[str]],
) -> tuple[str, list[str], bool]

Prepare weights for the model.

If the model is not local, it will be downloaded.

Source code in vllm/model_executor/model_loader/default_loader.py
def _prepare_weights(
    self,
    model_name_or_path: str,
    revision: Optional[str],
    fall_back_to_pt: bool,
    allow_patterns_overrides: Optional[list[str]],
) -> tuple[str, list[str], bool]:
    """Prepare weights for the model.

    If the model is not local, it will be downloaded.

    Args:
        model_name_or_path: Model ID or path to a local directory.
        revision: Optional model revision passed to the downloaders.
        fall_back_to_pt: Whether ``*.pt`` files may be used when none of
            the format's own patterns match.
        allow_patterns_overrides: If not ``None``, replaces the glob
            patterns derived from the load format entirely.

    Returns:
        A tuple ``(hf_folder, hf_weights_files, use_safetensors)``: the
        directory that was searched, the weight files selected from it,
        and whether those files are to be read as safetensors.

    Raises:
        ValueError: If ``load_config.load_format`` is not one of the
            formats handled here.
        RuntimeError: If no weight file is left after filtering.
    """
    model_name_or_path = (
        maybe_download_from_modelscope(model_name_or_path, revision)
        or model_name_or_path
    )

    is_local = os.path.isdir(model_name_or_path)
    load_format = self.load_config.load_format
    use_safetensors = False
    index_file = SAFE_WEIGHTS_INDEX_NAME
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors" or load_format == "fastsafetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "mistral":
        use_safetensors = True
        allow_patterns = ["consolidated*.safetensors"]
        index_file = "consolidated.safetensors.index.json"
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    # Some quantized models use .pt files for storing the weights.
    if fall_back_to_pt:
        allow_patterns += ["*.pt"]

    if allow_patterns_overrides is not None:
        allow_patterns = allow_patterns_overrides

    if not is_local:
        hf_folder = download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            allow_patterns,
            revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )
    else:
        hf_folder = model_name_or_path

    # Patterns are tried in priority order; the first one that matches
    # any file decides which files (and which file type) are loaded.
    hf_weights_files: list[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
        if len(hf_weights_files) > 0:
            # Detect safetensors by suffix rather than by comparing with
            # the literal "*.safetensors", so that override patterns such
            # as "model-*.safetensors" are not mistaken for .pt/.bin
            # checkpoints.
            if pattern.endswith(".safetensors"):
                use_safetensors = True
            break

    if use_safetensors:
        # For models like Mistral-7B-Instruct-v0.3
        # there are both sharded safetensors files and a consolidated
        # safetensors file. Using both breaks.
        # Here, we download the `model.safetensors.index.json` and filter
        # any files not found in the index.
        if not is_local:
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                index_file,
                self.load_config.download_dir,
                revision,
            )
        hf_weights_files = filter_duplicate_safetensors_files(
            hf_weights_files, hf_folder, index_file
        )
    else:
        hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`"
        )

    return hf_folder, hf_weights_files, use_safetensors

download_model

download_model(model_config: ModelConfig) -> None
Source code in vllm/model_executor/model_loader/default_loader.py
def download_model(self, model_config: ModelConfig) -> None:
    """Run weight preparation for the configured model, discarding the result.

    ``.pt`` fallback is allowed and no pattern overrides are applied.
    """
    model_id = model_config.model
    model_revision = model_config.revision
    self._prepare_weights(
        model_id,
        model_revision,
        fall_back_to_pt=True,
        allow_patterns_overrides=None,
    )

get_all_weights

get_all_weights(
    model_config: ModelConfig, model: Module
) -> Generator[tuple[str, Tensor], None, None]
Source code in vllm/model_executor/model_loader/default_loader.py
def get_all_weights(
    self,
    model_config: ModelConfig,
    model: nn.Module,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield every ``(name, tensor)`` pair the model should load.

    Weights of the primary source (built from ``model_config``) come first.
    Sources listed in the model's optional ``secondary_weights`` attribute
    follow, in order.
    """
    model_id = model_config.model
    model_revision = model_config.revision
    # Optional per-model knobs; the defaults apply when the attribute is absent.
    pt_fallback = getattr(model, "fall_back_to_pt_during_load", True)
    pattern_overrides = getattr(model, "allow_patterns_overrides", None)

    main_source = DefaultModelLoader.Source(
        model_id,
        model_revision,
        prefix="",
        fall_back_to_pt=pt_fallback,
        allow_patterns_overrides=pattern_overrides,
    )
    yield from self._get_weights_iterator(main_source)

    # Looked up only after the primary weights have been fully yielded.
    extra_sources = cast(
        Iterable[DefaultModelLoader.Source],
        getattr(model, "secondary_weights", ()),
    )
    for extra_source in extra_sources:
        yield from self._get_weights_iterator(extra_source)

load_weights

load_weights(
    model: Module, model_config: ModelConfig
) -> None
Source code in vllm/model_executor/model_loader/default_loader.py
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    """Load the checkpoint weights into ``model`` and log the time taken.

    Args:
        model: Module whose ``load_weights`` method consumes the
            ``(name, tensor)`` pairs.
        model_config: Provides the model ID, revision and quantization
            method.

    Raises:
        ValueError: If the model is not quantized, ``model.load_weights``
            returned the set of loaded names, and at least one named
            parameter of ``model`` is missing from that set.
    """
    if model_config.quantization == "torchao" and torchao_version_at_least(
        "0.14.0"
    ):
        self.load_config.safetensors_load_strategy = "torchao"
    weights_to_load = {name for name, _ in model.named_parameters()}

    # Re-arm the start timestamp. The weights iterator only sets it while it
    # is 0.0, so without this reset a second call on the same loader would
    # report the time elapsed since the *first* load.
    self.counter_before_loading_weights = 0.0

    # if we don't have `model.weight_metadata_and_attr_saved` defined and
    # set to True, it means that this is either offline quantization case
    # or the first run of online quantization
    # see online_quantization.py for detailed notes
    offline_quantization_or_first_run_of_online_quantization = not getattr(
        model, "weight_metadata_and_attr_saved", False
    )

    if (
        model_config.quantization is None
        or offline_quantization_or_first_run_of_online_quantization
    ):
        # Either the model is not quantized, or:
        # case 1: offline quantized checkpoint
        # case 2: Step I1 first run of weight loading with
        # online quantization
        # see online_quantization.py for detailed notes
        loaded_weights = model.load_weights(
            self.get_all_weights(model_config, model)
        )
    else:
        # to avoid circular dependency
        from vllm.model_executor.model_loader.online_quantization import (
            load_weights_and_online_quantize,
        )

        # subsequent runs of weight loading with online
        # quantization
        loaded_weights = load_weights_and_online_quantize(self, model, model_config)

    self.counter_after_loading_weights = time.perf_counter()
    logger.info(
        "Loading weights took %.2f seconds",
        self.counter_after_loading_weights - self.counter_before_loading_weights,
    )
    # We only enable strict check for non-quantized models
    # that have loaded weights tracking currently.
    if model_config.quantization is None and loaded_weights is not None:
        weights_not_loaded = weights_to_load - loaded_weights
        if weights_not_loaded:
            raise ValueError(
                "Following weights were not initialized from "
                f"checkpoint: {weights_not_loaded}"
            )