vllm.model_executor.models.transformers

Wrapper around transformers models
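
As context for the classes documented below, the Transformers backend is typically selected through vLLM's model_impl option. A minimal, hedged sketch (the model name is illustrative):

from vllm import LLM

# Force the Transformers backend instead of a native vLLM implementation.
# The model name below is illustrative only.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", model_impl="transformers")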

Style module-attribute

Style = Literal[
    "colwise",
    "colwise_rep",
    "rowwise",
    "rowwise_rep",
    "replicate",
]
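
These styles mirror the sharding annotations in a Hugging Face tensor parallel plan, which maps module-name patterns to a Style. A purely illustrative sketch (the pattern keys are hypothetical; real plans come from the Transformers model, e.g. its base_model_tp_plan):

# Hypothetical plan, for illustration only.
example_tp_plan: dict[str, Style] = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
    "lm_head": "colwise_rep",
}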

logger module-attribute

logger = init_logger(__name__)

MultiModalDummyInputsBuilder

Bases: BaseDummyInputsBuilder[MultiModalProcessingInfo]

Source code in vllm/model_executor/models/transformers.py
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        if "gemma3" in processor.__class__.__name__.lower():
            image_token = processor.boi_token
        else:
            image_token = getattr(processor, "image_token", "")
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_max_image_size()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }

get_dummy_mm_data

get_dummy_mm_data(
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Optional[
        Mapping[str, BaseDummyOptions]
    ] = None,
) -> MultiModalDataDict
Source code in vllm/model_executor/models/transformers.py
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
    num_images = mm_counts.get("image", 0)

    target_width, target_height = self.info.get_max_image_size()

    image_overrides = mm_options.get("image") if mm_options else None

    return {
        "image": self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=num_images,
            overrides=image_overrides,
        ),
    }

get_dummy_text

get_dummy_text(mm_counts: Mapping[str, int]) -> str
Source code in vllm/model_executor/models/transformers.py
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    num_images = mm_counts.get("image", 0)

    processor = self.info.get_hf_processor()
    if "gemma3" in processor.__class__.__name__.lower():
        image_token = processor.boi_token
    else:
        image_token = getattr(processor, "image_token", "")
    return image_token * num_images

MultiModalProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/transformers.py
class MultiModalProcessingInfo(BaseProcessingInfo):
    def get_supported_mm_limits(self):
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        return {"image": self.get_max_image_tokens()}

    def get_max_image_tokens(self) -> int:
        width, height = self.get_max_image_size()
        processor = self.get_hf_processor()
        multimodal_config = self.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
        mm_tokens = processor._get_num_multimodal_tokens(
            image_sizes=([height, width],), **mm_processor_kwargs
        )
        image_tokens = mm_tokens["num_image_tokens"][0]
        return image_tokens

    def get_max_image_size(self):
        return 10_000, 10_000  # hardcoded, arbitrarily large size

get_max_image_size

get_max_image_size()
Source code in vllm/model_executor/models/transformers.py
def get_max_image_size(self):
    return 10_000, 10_000  # hardcoded, arbitrarily large size

get_max_image_tokens

get_max_image_tokens() -> int
Source code in vllm/model_executor/models/transformers.py
def get_max_image_tokens(self) -> int:
    width, height = self.get_max_image_size()
    processor = self.get_hf_processor()
    multimodal_config = self.ctx.model_config.multimodal_config
    mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
    mm_tokens = processor._get_num_multimodal_tokens(
        image_sizes=([height, width],), **mm_processor_kwargs
    )
    image_tokens = mm_tokens["num_image_tokens"][0]
    return image_tokens

get_mm_max_tokens_per_item

get_mm_max_tokens_per_item(seq_len, mm_counts)
Source code in vllm/model_executor/models/transformers.py
def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
    return {"image": self.get_max_image_tokens()}

get_supported_mm_limits

get_supported_mm_limits()
Source code in vllm/model_executor/models/transformers.py
def get_supported_mm_limits(self):
    return {"image": None}

MultiModalProcessor

Bases: BaseMultiModalProcessor[MultiModalProcessingInfo]

Source code in vllm/model_executor/models/transformers.py
class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        the HF processor if the HF processor does not apply prompt updates to
        text inputs.

        Moreover, this information is critical to determine the token positions
        needed to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        num_image_patches = hf_inputs.get("num_image_patches")
        mm_fields = {
            key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
            for key in hf_inputs
        }
        mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
            "image", num_image_patches
        )

        # Keep these as batched, as they always have batch size as first dim
        mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
        mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
        mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
        return mm_fields

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        In contrast to the base class, this method always adds
        `return_mm_token_type_ids` to the processor data
        """
        processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
        processor_data["return_mm_token_type_ids"] = True
        return processor_data, passthrough_data

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        if not isinstance(prompt, str):
            # The prompt is a list of token ids, which the hf_processor does
            # not accept, so we decode it back into a string first
            prompt = hf_processor.decode(prompt)

        # Bypass cached processor and always apply to the full set of mm inputs
        # NOTE: we can't just set caching=False because base class method
        # transforms outputs to `MultiModalKwargs` which is not going to
        # work for Transformers. We have a lot of logic tied to
        # `mm_tokens_per_modality` below
        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        # For gemma3 we check `token_type_ids` as the key
        token_type_key = (
            "mm_token_type_ids"
            if "mm_token_type_ids" in processed_data
            else "token_type_ids"
        )
        mm_token_type_ids = processed_data.pop(token_type_key)

        # We can infer vLLM style placeholder from token type ids, if we split
        # it for each input `mm_data`.
        mm_positions = torch.where(mm_token_type_ids == 1)[1]
        images = mm_items.get_items("image", ImageProcessorItems)
        multimodal_config = self.info.ctx.model_config.multimodal_config
        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
        image_sizes = []
        for item_idx in range(len(images)):
            image_size = images.get_image_size(item_idx)
            image_sizes.append((image_size.height, image_size.width))

        mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **mm_processor_kwargs
        )

        mm_placeholders = {}
        split_sizes = mm_tokens_per_modality["num_image_tokens"]
        if split_sizes:
            chunked_mm_positions = torch.split(mm_positions, split_sizes)
            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
            chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
            ranges = [
                PlaceholderRange(
                    offset=positions[0].item(),
                    length=positions.shape[0],
                    is_embed=(mm_tokens == hf_processor.image_token_id).bool(),
                )
                for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
            ]
            mm_placeholders = {"image": ranges}

        processed_data["num_image_patches"] = torch.tensor(
            mm_tokens_per_modality["num_image_patches"]
        )
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = self._hash_mm_items(
            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )

_get_hf_mm_data

_get_hf_mm_data(
    mm_items: MultiModalDataItems,
) -> tuple[Mapping[str, object], Mapping[str, object]]

In contrast to the base class, this method always adds return_mm_token_type_ids to the processor data

Source code in vllm/model_executor/models/transformers.py
def _get_hf_mm_data(
    self,
    mm_items: MultiModalDataItems,
) -> tuple[Mapping[str, object], Mapping[str, object]]:
    """
    In contrast to the base class, this method always adds
    `return_mm_token_type_ids` to the processor data
    """
    processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
    processor_data["return_mm_token_type_ids"] = True
    return processor_data, passthrough_data

_get_mm_fields_config

_get_mm_fields_config(
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]
Source code in vllm/model_executor/models/transformers.py
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    # HF Processors always return a mask but vLLM doesn't need it
    hf_inputs.pop("attention_mask", None)
    num_image_patches = hf_inputs.get("num_image_patches")
    mm_fields = {
        key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
        for key in hf_inputs
    }
    mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
        "image", num_image_patches
    )

    # Keep these as batched, as they always have batch size as first dim
    mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
    mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
    mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
    return mm_fields

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
)

Given the original multi-modal items for this modality and HF-processed data, output the updates to perform.

The information returned by this method is used to update token inputs which bypass the HF processor. It is also used to update the output of the HF processor if the HF processor does not apply prompt updates to text inputs.

Moreover, this information is critical to determine the token positions needed to construct vllm.multimodal.inputs.PlaceholderRange for each multi-modal item.
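
For intuition, a hedged sketch of what a single placeholder range encodes, using only the fields exercised in apply below (the import path is assumed to be vllm.multimodal.inputs):

import torch

from vllm.multimodal.inputs import PlaceholderRange

# One image whose 5 placeholder tokens start at prompt position 3; the boolean
# mask marks which of those positions hold actual image-embedding tokens.
image_range = PlaceholderRange(
    offset=3,
    length=5,
    is_embed=torch.tensor([True, True, True, True, True]),
)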

Source code in vllm/model_executor/models/transformers.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
):
    """
    Given the original multi-modal items for this modality
    and HF-processed data, output the updates to perform.

    The information returned by this method is used to update token inputs
    which bypass the HF processor. It is also used to update the output of
    the HF processor if the HF processor does not apply prompt updates to
    text inputs.

    Moreover, this information is critical to determine the token positions
    needed to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
    for each multi-modal item.
    """
    return None

apply

apply(
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Optional[
        Mapping[str, object]
    ] = None,
    mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs

Process multi-modal inputs to be used in vLLM.

Apply HF Processor on prompt text and multi-modal data together, outputting token IDs and processed tensors.
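
An illustrative call, assuming processor is a MultiModalProcessor already constructed through vLLM's multimodal registry for a Transformers-backend model; the image placeholder token is model-specific:

from PIL import Image

image = Image.new("RGB", (336, 336))

mm_inputs = processor.apply(
    prompt="<image> What is shown in this picture?",
    mm_data={"image": [image]},
    hf_processor_mm_kwargs={},
)
# MultiModalInputs is treated as a plain mapping of the fields returned below.
print(mm_inputs["prompt_token_ids"][:10])
print(mm_inputs["mm_placeholders"]["image"])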

Source code in vllm/model_executor/models/transformers.py
def apply(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Optional[Mapping[str, object]] = None,
    mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
    """
    Process multi-modal inputs to be used in vLLM.

    Apply HF Processor on prompt text and multi-modal data together,
    outputting token IDs and processed tensors.
    """
    if tokenization_kwargs is None:
        tokenization_kwargs = {}

    mm_items = self._to_mm_items(mm_data)
    hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    if not isinstance(prompt, str):
        # The prompt is a list of token ids, which the hf_processor does
        # not accept, so we decode it back into a string first
        prompt = hf_processor.decode(prompt)

    # Bypass cached processor and always apply to the full set of mm inputs
    # NOTE: we can't just set caching=False because base class method
    # transforms outputs to `MultiModalKwargs` which is not going to
    # work for Transformers. We have a lot of logic tied to
    # `mm_tokens_per_modality` below
    prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
        prompt_text=prompt,
        mm_items=mm_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        tokenization_kwargs=tokenization_kwargs,
    )

    # For gemma3 we check `token_type_ids` as the key
    token_type_key = (
        "mm_token_type_ids"
        if "mm_token_type_ids" in processed_data
        else "token_type_ids"
    )
    mm_token_type_ids = processed_data.pop(token_type_key)

    # We can infer vLLM style placeholder from token type ids, if we split
    # it for each input `mm_data`.
    mm_positions = torch.where(mm_token_type_ids == 1)[1]
    images = mm_items.get_items("image", ImageProcessorItems)
    multimodal_config = self.info.ctx.model_config.multimodal_config
    mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
    image_sizes = []
    for item_idx in range(len(images)):
        image_size = images.get_image_size(item_idx)
        image_sizes.append((image_size.height, image_size.width))

    mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
        image_sizes=image_sizes, **mm_processor_kwargs
    )

    mm_placeholders = {}
    split_sizes = mm_tokens_per_modality["num_image_tokens"]
    if split_sizes:
        chunked_mm_positions = torch.split(mm_positions, split_sizes)
        mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
        chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
        ranges = [
            PlaceholderRange(
                offset=positions[0].item(),
                length=positions.shape[0],
                is_embed=(mm_tokens == hf_processor.image_token_id).bool(),
            )
            for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
        ]
        mm_placeholders = {"image": ranges}

    processed_data["num_image_patches"] = torch.tensor(
        mm_tokens_per_modality["num_image_patches"]
    )
    mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
        processed_data,
        self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
    )

    # Use overrides if provided; fallback to data-dependent hashing.
    mm_hashes = self._hash_mm_items(
        mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
    )

    return MultiModalInputs(
        type="multimodal",
        prompt_token_ids=prompt_ids,
        mm_kwargs=mm_kwargs,
        mm_hashes=mm_hashes,
        mm_placeholders=mm_placeholders,
    )

TransformersBase

Bases: Module, SupportsQuant, SupportsLoRA, SupportsPP

Source code in vllm/model_executor/models/transformers.py
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

        # Weights to skip in `self.load_weights`
        self.skip_prefixes: list[str] = []
        """Skip loading weights whose qualname starts with these prefixes."""
        self.skip_substrs: list[str] = []
        """Skip loading weights whose qualname contains these substrings."""
        self.ignore_unexpected_prefixes: list[str] = []
        """Ignore unexpected weights whose qualname starts with these prefixes.
        """
        self.ignore_unexpected_suffixes: list[str] = []
        """Ignore unexpected weights whose qualname ends with these suffixes."""

        if self.quant_config:
            quant_method_name = self.quant_config.get_name()
            # Check for unsupported quantization methods.
            if quant_method_name == "mxfp4":
                raise NotImplementedError(
                    "Transformers backend does not support MXFP4 quantization yet."
                )
            # Skip loading extra bias for GPTQ models.
            if "gptq" in quant_method_name:
                self.ignore_unexpected_suffixes.append(".bias")

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
        self.recursive_replace()
        # Create attention instances for KV cache allocation
        self.attention_instances = self.create_attention_instances()

        # Input embeddings
        input_embeddings = self.model.get_input_embeddings()
        if not isinstance(input_embeddings, PPMissingLayer):
            # Some models use embedding scales
            self.embed_scale = getattr(input_embeddings, "embed_scale", None)
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            assert embedding_dim is not None
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                )
            )

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.text_config.hidden_size
        )

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.
        """
        if self.pp_group.world_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!"
            )
        if module_list_idx is None:
            raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        for name in pp_plan[:module_list_idx]:
            if self.pp_group.is_first_rank or (
                self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
            ):
                continue
            setattr(self.model, name, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            if start_layer <= i and i < end_layer:
                continue
            layers[i] = PPMissingLayer()

        # Layers after module list
        for name in pp_plan[module_list_idx + 1 :]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                setattr(self.model, name, PPMissingLayer())

    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Currently, this replaces:

        - `nn.Linear` with vLLM's tensor parallel linear classes
        - `*RMSNorm` with vLLM's `RMSNorm`
        """
        tp_plan = self.model.tp_plan

        if not tp_plan and self.tp_group.world_size > 1:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}"
            )

        # Prefix the patterns because we always start from `self.model`
        tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

        def _recursive_replace(module: nn.Module, prefix: str):
            for child_name, child_module in module.named_children():
                new_module = child_module
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    # Some weight loaders expect all linear layers to inherit
                    # LinearBase, so we set a default style which causes any
                    # unspecified layers to be replaced with ReplicatedLinear
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(
                        child_module, style, self.quant_config, prefix=qual_name
                    )
                # TODO(hmellor): Enable RMSNorm replacement once we have a way
                # to choose RMSNorm vs GemmaRMSNorm
                # elif child_module.__class__.__name__.endswith("RMSNorm"):
                #     new_module = replace_rms_norm_class(
                #         child_module, self.config.hidden_size)
                else:
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)

        _recursive_replace(self.model, prefix="model")

    def create_attention_instances(
        self, attn_type: AttentionType = AttentionType.DECODER
    ) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.
        """
        num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None)
        start, end = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if (
                hasattr(self.config, "layer_types")
                and self.config.layer_types[i] == "sliding_attention"
            ):
                per_layer_sliding_window = self.config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use the Llama scale as the default; if Transformers
                # sets a scale, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                logits_soft_cap=logits_soft_cap,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type,
            )
        return attention_instances

    def init_parameters(self, module: nn.Module, dtype: Optional[torch.dtype] = None):
        """
        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```
        """

        def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
            for name, param in module.named_parameters(recurse=False):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(
                            param.data,
                            dtype=dtype or self.model_config.dtype,
                            device=self.device_config.device,
                        )
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_parameters(child, dtype)

        _init_parameters(module, dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False,
            **kwargs,
        )[0][0, ...]  # we remove batch dimension for now

        if not self.pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def check_version(self, min_version: str, feature: str):
        installed = Version(transformers.__version__)
        required = Version(min_version)
        if installed < required:
            raise ImportError(
                f"Transformers backend requires transformers>={required} "
                f"for {feature}, but got {installed}"
            )

attention_instances instance-attribute

attention_instances = create_attention_instances()

cache_config instance-attribute

cache_config: CacheConfig = cache_config

config instance-attribute

config: PretrainedConfig = hf_config

device_config instance-attribute

device_config: DeviceConfig = device_config

embed_scale instance-attribute

embed_scale = getattr(input_embeddings, "embed_scale", None)

embedding_modules class-attribute instance-attribute

embedding_modules = ['embed_tokens']

embedding_padding_modules class-attribute instance-attribute

embedding_padding_modules = ['lm_head']

ignore_unexpected_prefixes instance-attribute

ignore_unexpected_prefixes: list[str] = []

Ignore unexpected weights whose qualname starts with these prefixes.

ignore_unexpected_suffixes instance-attribute

ignore_unexpected_suffixes: list[str] = []

Ignore unexpected weights whose qualname ends with these suffixes.

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], hidden_size
    )
)

model instance-attribute

model: PreTrainedModel = from_config(
    config,
    torch_dtype=dtype,
    trust_remote_code=trust_remote_code,
)

model_config instance-attribute

model_config: ModelConfig = model_config

parallel_config instance-attribute

parallel_config: ParallelConfig = parallel_config

pp_group instance-attribute

pp_group = get_pp_group()

quant_config instance-attribute

quant_config: Optional[QuantizationConfig] = quant_config

skip_prefixes instance-attribute

skip_prefixes: list[str] = []

Skip loading weights whose qualname starts with these prefixes.

skip_substrs instance-attribute

skip_substrs: list[str] = []

Skip loading weights whose qualname contains these substrings.

text_config instance-attribute

text_config: PretrainedConfig = get_text_config()

tp_group instance-attribute

tp_group = get_tp_group()

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/transformers.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    logger.info("Using Transformers backend.")

    self.config: PretrainedConfig = vllm_config.model_config.hf_config
    self.text_config: PretrainedConfig = self.config.get_text_config()
    self.cache_config: CacheConfig = vllm_config.cache_config
    self.device_config: DeviceConfig = vllm_config.device_config
    self.model_config: ModelConfig = vllm_config.model_config
    self.parallel_config: ParallelConfig = vllm_config.parallel_config
    self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config

    self.pp_group = get_pp_group()
    self.tp_group = get_tp_group()

    # Weights to skip in `self.load_weights`
    self.skip_prefixes: list[str] = []
    """Skip loading weights whose qualname starts with these prefixes."""
    self.skip_substrs: list[str] = []
    """Skip loading weights whose qualname contains these substrings."""
    self.ignore_unexpected_prefixes: list[str] = []
    """Ignore unexpected weights whose qualname starts with these prefixes.
    """
    self.ignore_unexpected_suffixes: list[str] = []
    """Ignore unexpected weights whose qualname ends with these suffixes."""

    if self.quant_config:
        quant_method_name = self.quant_config.get_name()
        # Check for unsupported quantization methods.
        if quant_method_name == "mxfp4":
            raise NotImplementedError(
                "Transformers backend does not support MXFP4 quantization yet."
            )
        # Skip loading extra bias for GPTQ models.
        if "gptq" in quant_method_name:
            self.ignore_unexpected_suffixes.append(".bias")

    # Set correct attn and init on "meta" to delay allocating GPU tensors
    self.text_config._attn_implementation = "vllm"
    with init_on_device_without_buffers("meta"):
        self.model: PreTrainedModel = AutoModel.from_config(
            self.config,
            torch_dtype=self.model_config.dtype,
            trust_remote_code=self.model_config.trust_remote_code,
        )

    # Remove layers not on this pipeline parallel rank
    self.pipeline_parallel()
    # Substitute remaining layers with vLLM's layers as needed
    self.recursive_replace()
    # Create attention instances for KV cache allocation
    self.attention_instances = self.create_attention_instances()

    # Input embeddings
    input_embeddings = self.model.get_input_embeddings()
    if not isinstance(input_embeddings, PPMissingLayer):
        # Some models use embedding scales
        self.embed_scale = getattr(input_embeddings, "embed_scale", None)
        names = ("embedding_size", "hidden_size")
        embedding_dim = getattr_iter(self.text_config, names, None)
        assert embedding_dim is not None
        self.model.set_input_embeddings(
            VocabParallelEmbedding(
                self.text_config.vocab_size,
                embedding_dim=embedding_dim,
                org_num_embeddings=self.text_config.vocab_size,
                quant_config=self.quant_config,
            )
        )

    # Initialize any parameters that have not had their modules replaced
    self.init_parameters(self.model)

    # Pipeline parallel intermediate tensors
    self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
        ["hidden_states"], self.text_config.hidden_size
    )

check_version

check_version(min_version: str, feature: str)
Source code in vllm/model_executor/models/transformers.py
def check_version(self, min_version: str, feature: str):
    installed = Version(transformers.__version__)
    required = Version(min_version)
    if installed < required:
        raise ImportError(
            f"Transformers backend requires transformers>={required} "
            f"for {feature}, but got {installed}"
        )

create_attention_instances

create_attention_instances(
    attn_type: AttentionType = DECODER,
) -> dict[int, Attention]

Create Attention instances to inform KV cache allocation.

Source code in vllm/model_executor/models/transformers.py
def create_attention_instances(
    self, attn_type: AttentionType = AttentionType.DECODER
) -> dict[int, Attention]:
    """
    Create `Attention` instances to inform KV cache allocation.
    """
    num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
    head_size = self.model_config.get_head_size()
    num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
    logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None)
    start, end = get_pp_indices(
        self.text_config.num_hidden_layers,
        self.pp_group.rank_in_group,
        self.pp_group.world_size,
    )

    attention_instances = {}
    for i in range(start, end):
        # Handle interleaved sliding window attention
        per_layer_sliding_window = None
        if (
            hasattr(self.config, "layer_types")
            and self.config.layer_types[i] == "sliding_attention"
        ):
            per_layer_sliding_window = self.config.sliding_window

        attention_instances[i] = Attention(
            num_heads=num_heads,
            head_size=head_size,
            # NOTE: We use the Llama scale as the default; if Transformers
            # sets a scale, it's updated in vllm_flash_attention_forward
            scale=head_size**-0.5,
            num_kv_heads=num_kv_heads,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
            logits_soft_cap=logits_soft_cap,
            per_layer_sliding_window=per_layer_sliding_window,
            prefix=f"{i}.attn",
            attn_type=attn_type,
        )
    return attention_instances

forward

forward(
    input_ids: Optional[Tensor],
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/transformers.py
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
    if not self.pp_group.is_first_rank:
        assert intermediate_tensors is not None
        input_ids = None
        inputs_embeds = intermediate_tensors["hidden_states"]

    if input_ids is not None:
        input_ids = input_ids[None, ...]
    if inputs_embeds is not None:
        inputs_embeds = inputs_embeds[None, ...]

    if self.model_config.uses_mrope:
        position_ids = positions[:, None]
    else:
        position_ids = positions[None, ...]

    hidden_states = self.model(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        use_cache=False,
        position_ids=position_ids,
        attention_instances=self.attention_instances,
        return_dict=False,
        **kwargs,
    )[0][0, ...]  # we remove batch dimension for now

    if not self.pp_group.is_last_rank:
        return IntermediateTensors({"hidden_states": hidden_states})

    return hidden_states

init_parameters

init_parameters(
    module: Module, dtype: Optional[dtype] = None
)

If a parameter is on the meta device, then its parent module is the original module created by:

with torch.device("meta"):
    self.model: PreTrainedModel = AutoModel.from_config(...)
Source code in vllm/model_executor/models/transformers.py
def init_parameters(self, module: nn.Module, dtype: Optional[torch.dtype] = None):
    """
    If a `parameter` is on the `meta` device, then its parent
    `module` is the original module created by:

    ```python
    with torch.device("meta"):
        self.model: PreTrainedModel = AutoModel.from_config(...)
    ```
    """

    def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
        for name, param in module.named_parameters(recurse=False):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(
                        param.data,
                        dtype=dtype or self.model_config.dtype,
                        device=self.device_config.device,
                    )
                )
                setattr(module, name, new_param)
        for child in module.children():
            _init_parameters(child, dtype)

    _init_parameters(module, dtype)
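
For context, a standalone sketch of the meta-device pattern this method handles; it is illustrative only and uses "cpu" and float16 in place of the configured device and dtype:

import torch
import torch.nn as nn

# Modules created under the "meta" device carry shapes and dtypes but
# allocate no storage.
with torch.device("meta"):
    layer = nn.Linear(16, 16)

print(layer.weight.device)  # meta

# Materialize the parameter on a real device, mirroring `_init_parameters`.
layer.weight = nn.Parameter(
    torch.empty_like(layer.weight, dtype=torch.float16, device="cpu")
)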

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/transformers.py
def load_weights(
    self,
    weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
    loader = AutoWeightsLoader(
        self,
        skip_prefixes=self.skip_prefixes,
        skip_substrs=self.skip_substrs,
        ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
        ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
    )
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

pipeline_parallel

pipeline_parallel()

Apply the model's pipeline parallelization plan.

Source code in vllm/model_executor/models/transformers.py
def pipeline_parallel(self):
    """
    Apply the model's pipeline parallelization plan.
    """
    if self.pp_group.world_size <= 1:
        return

    if not self.model.supports_pp_plan:
        tip = get_feature_request_tip(
            self.model_config.model, self.model_config.trust_remote_code
        )
        raise ValueError(
            f"{type(self.model)} does not support pipeline parallel. {tip}"
        )

    module_lists = []
    module_list_idx = None
    pp_plan = list(self.model._pp_plan.keys())
    for i, name in enumerate(pp_plan):
        if isinstance(getattr(self.model, name), nn.ModuleList):
            module_lists.append(name)
            module_list_idx = i

    if len(module_lists) > 1:
        raise ValueError(
            "Pipeline parallel of models with multiple `ModuleList`s "
            "in the base model are not supported yet!"
        )
    if module_list_idx is None:
        raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")

    # Layers before module list
    for name in pp_plan[:module_list_idx]:
        if self.pp_group.is_first_rank or (
            self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
        ):
            continue
        setattr(self.model, name, PPMissingLayer())

    # Module list
    start_layer, end_layer = get_pp_indices(
        self.text_config.num_hidden_layers,
        self.pp_group.rank_in_group,
        self.pp_group.world_size,
    )
    layers_name = pp_plan[module_list_idx]
    layers = getattr(self.model, layers_name)
    for i in range(len(layers)):
        if start_layer <= i and i < end_layer:
            continue
        layers[i] = PPMissingLayer()

    # Layers after module list
    for name in pp_plan[module_list_idx + 1 :]:
        # Modules that should be on last rank
        if not self.pp_group.is_last_rank:
            setattr(self.model, name, PPMissingLayer())

recursive_replace

recursive_replace()

Recursively replace modules in the model as needed.

Currently, this replaces:

  • nn.Linear with vLLM's tensor parallel linear classes
  • *RMSNorm with vLLM's RMSNorm
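
The nn.Linear replacement selects a style by regex-matching each layer's qualified name against the tensor parallel plan. A minimal sketch of that lookup (the plan keys here are hypothetical):

import re

tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
}

qual_name = "model.layers.0.self_attn.q_proj"
pattern = next((p for p in tp_plan if re.match(p, qual_name)), None)
# Unmatched layers fall back to "replicate", i.e. ReplicatedLinear.
style = tp_plan.get(pattern, "replicate")
print(style)  # colwise
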
Source code in vllm/model_executor/models/transformers.py
def recursive_replace(self):
    """Recursively replace modules in the model as needed.

    Currently, this replaces:

    - `nn.Linear` with vLLM's tensor parallel linear classes
    - `*RMSNorm` with vLLM's `RMSNorm`
    """
    tp_plan = self.model.tp_plan

    if not tp_plan and self.tp_group.world_size > 1:
        tip = get_feature_request_tip(
            self.model_config.model, self.model_config.trust_remote_code
        )
        raise ValueError(
            f"{type(self.model)} does not support tensor parallel. {tip}"
        )

    # Prefix the patterns because we always start from `self.model`
    tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

    def _recursive_replace(module: nn.Module, prefix: str):
        for child_name, child_module in module.named_children():
            new_module = child_module
            qual_name = maybe_prefix(prefix, child_name)
            if isinstance(child_module, nn.Linear):
                generator = (p for p in tp_plan if re.match(p, qual_name))
                pattern = next(generator, None)
                # Some weight loaders expect all linear layers to inherit
                # LinearBase, so we set a default style which causes any
                # unspecified layers to be replaced with ReplicatedLinear
                style = tp_plan.get(pattern, "replicate")
                new_module = replace_linear_class(
                    child_module, style, self.quant_config, prefix=qual_name
                )
            # TODO(hmellor): Enable RMSNorm replacement once we have a way
            # to choose RMSNorm vs GemmaRMSNorm
            # elif child_module.__class__.__name__.endswith("RMSNorm"):
            #     new_module = replace_rms_norm_class(
            #         child_module, self.config.hidden_size)
            else:
                _recursive_replace(child_module, prefix=qual_name)

            if new_module is not child_module:
                setattr(module, child_name, new_module)
                log_replacement(qual_name, child_module, new_module)

    _recursive_replace(self.model, prefix="model")

TransformersForCausalLM

Bases: TransformersBase

Source code in vllm/model_executor/models/transformers.py
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Tell `TransformersBase.load_weights` to skip
        # `lm_head` if the model has tied word embeddings
        if self.text_config.tie_word_embeddings:
            self.skip_prefixes.append("lm_head.")

        if self.pp_group.is_last_rank:
            self.unpadded_vocab_size = self.text_config.vocab_size
            self.lm_head = ParallelLMHead(
                self.text_config.vocab_size,
                self.text_config.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if self.text_config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.get_input_embeddings()
                )

            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if self.embed_scale is not None:
            inputs_embeds *= self.embed_scale
        return inputs_embeds

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

lm_head instance-attribute

lm_head = ParallelLMHead(
    vocab_size,
    hidden_size,
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "lm_head"),
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    unpadded_vocab_size, vocab_size, logit_scale
)

unpadded_vocab_size instance-attribute

unpadded_vocab_size = vocab_size

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/transformers.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__(vllm_config=vllm_config, prefix=prefix)

    # Tell `TransformersBase.load_weights` to skip
    # `lm_head` if the model has tied word embeddings
    if self.text_config.tie_word_embeddings:
        self.skip_prefixes.append("lm_head.")

    if self.pp_group.is_last_rank:
        self.unpadded_vocab_size = self.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.text_config.vocab_size,
            self.text_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.text_config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(
                self.model.get_input_embeddings()
            )

        logit_scale = getattr(self.text_config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale
        )
    else:
        self.lm_head = PPMissingLayer()

compute_logits

compute_logits(hidden_states: Tensor) -> Optional[Tensor]
Source code in vllm/model_executor/models/transformers.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    logits = self.logits_processor(self.lm_head, hidden_states)
    return logits

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/transformers.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    inputs_embeds = self.model.get_input_embeddings()(input_ids)
    if self.embed_scale is not None:
        inputs_embeds *= self.embed_scale
    return inputs_embeds

TransformersForMultimodalLM

Bases: TransformersForCausalLM, SupportsMultiModal

Source code in vllm/model_executor/models/transformers.py
@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
    dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
    # set `positions` to last dim to support Qwen-mrope
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    },
    enable_if=can_enable_torch_compile,
)
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
    supports_multimodal_raw_input_only = True
    merge_by_field_config = True
    # Backwards compatibility for previously released models. Their state
    # dicts had different formats and cannot be loaded with the `AutoModel`
    # mapping as is
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name for the language model.
            # Therefore, we must map each submodule explicitly to avoid
            # conflicts with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        self.dtype = vllm_config.model_config.dtype

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Gemma3 and PaliGemma need `token_type_ids` to work correctly
        # Other models will not have `token_type_ids` in kwargs
        kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
        model_output = super().forward(
            input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
        )
        return model_output

    def get_language_model(self) -> torch.nn.Module:
        """`TransformersForMultimodalLM` does not contain a vLLM language model class.
        Therefore, in order to return a vLLM language model class, we use a wrapper to
        give `self` the same interface as `TransformersForCausalLM`."""

        class LanguageModelWrapper(TransformersForCausalLM):
            def __init__(self, multimodal_model):
                # Don't call super().__init__() to avoid re-initialization
                self.__dict__.update(multimodal_model.__dict__)

            model = getattr_iter(self.model, ("language_model", "text_model"), None)

        return LanguageModelWrapper(self)

    def get_multimodal_embeddings(self, **kwargs):
        pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
        image_embeds: Optional[torch.Tensor] = kwargs.pop("image_embeds", None)
        # Model might use `image_patches` instead of `pixel_values`
        if pixel_values is None:
            pixel_values = kwargs.pop("image_patches", None)

        if image_embeds is not None:
            return image_embeds

        if pixel_values is None:
            return None

        num_image_patches = kwargs.pop("num_image_patches")
        kwargs.pop("token_type_ids", None)  # used only in `forward`
        if pixel_values is not None:
            vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)

            if isinstance(vision_embeddings, torch.Tensor):
                if vision_embeddings.ndim == 2:
                    vision_embeddings = vision_embeddings.unsqueeze(0)

                # Embeddings have to be 2D tensors of length `num_images`
                # but transformers returns concat tensors if each patch
                # is of different size. We split it back to make vLLM happy
                vision_embeddings = torch.split(
                    vision_embeddings, num_image_patches.flatten().tolist()
                )
                vision_embeddings = [
                    embed.flatten(start_dim=0, end_dim=-2)
                    for embed in vision_embeddings
                ]

            return vision_embeddings

    get_input_embeddings = SupportsMultiModal.get_input_embeddings

dtype instance-attribute

dtype = dtype

get_input_embeddings class-attribute instance-attribute

get_input_embeddings = get_input_embeddings

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "language_model.model": "model.language_model",
        "text_model.model": "model.text_model",
        "vision_tower": "model.vision_tower",
        "vqmodel": "model.vqmodel",
        "visual": "model.visual",
        "vision_model": "model.vision_model",
        "vision_embed_tokens": "model.vision_embed_tokens",
        "image_newline": "model.image_newline",
        "multi_modal_projector": "model.multi_modal_projector",
        "text_model.lm_head": "lm_head",
        "language_model.lm_head": "lm_head",
        "model.embed_tokens": "model.language_model.embed_tokens",
        "model.layers": "model.language_model.layers",
        "model.norm": "model.language_model.norm",
    }
)
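
How the prefix remapping plays out can be illustrated with a small, hypothetical helper. This is not vLLM's WeightsMapper implementation; it is a sketch that only mimics the orig_to_new_prefix rewrite, assuming no mapped prefix is itself a prefix of another entry (true for the table above).

# Hypothetical sketch of the prefix rewrite performed by hf_to_vllm_mapper.
ORIG_TO_NEW_PREFIX = {
    "language_model.model": "model.language_model",
    "visual": "model.visual",
    "model.embed_tokens": "model.language_model.embed_tokens",
    "model.layers": "model.language_model.layers",
    "model.norm": "model.language_model.norm",
}

def remap(name: str) -> str:
    # Rewrite the first matching checkpoint-key prefix to its vLLM module path.
    for orig, new in ORIG_TO_NEW_PREFIX.items():
        if name.startswith(orig):
            return new + name[len(orig):]
    return name

print(remap("model.layers.0.self_attn.q_proj.weight"))
# model.language_model.layers.0.self_attn.q_proj.weight  (Qwen-style checkpoint key)
print(remap("visual.blocks.0.attn.qkv.weight"))
# model.visual.blocks.0.attn.qkv.weight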

merge_by_field_config class-attribute instance-attribute

merge_by_field_config = True

supports_multimodal_raw_input_only class-attribute instance-attribute

supports_multimodal_raw_input_only = True

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/transformers.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__(vllm_config=vllm_config, prefix=prefix)

    self.dtype = vllm_config.model_config.dtype

forward

forward(
    input_ids: Optional[Tensor],
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs: object,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/transformers.py
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    # Gemma3 and PaliGemma need `token_type_ids` to work correctly
    # Other models will not have `token_type_ids` in kwargs
    kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
    model_output = super().forward(
        input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
    )
    return model_output
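
The kwargs filter can be seen in isolation with illustrative values: everything except token_type_ids is dropped before delegating to the text-only forward.

# Illustrative values only
kwargs = {"token_type_ids": [0, 0, 1, 1], "num_image_patches": [2]}
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
print(kwargs)  # {'token_type_ids': [0, 0, 1, 1]}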

get_language_model

get_language_model() -> Module

TransformersForMultimodalLM does not contain a vLLM language model class. Therefore, in order to return a vLLM language model class, we use a wrapper to give self the same interface as TransformersForCausalLM.

Source code in vllm/model_executor/models/transformers.py
def get_language_model(self) -> torch.nn.Module:
    """`TransformersForMultimodalLM` does not contain a vLLM language model class.
    Therefore, in order to return a vLLM language model class, we use a wrapper to
    give `self` the same interface as `TransformersForCausalLM`."""

    class LanguageModelWrapper(TransformersForCausalLM):
        def __init__(self, multimodal_model):
            # Don't call super().__init__() to avoid re-initialization
            self.__dict__.update(multimodal_model.__dict__)

        # The class body executes in the scope of get_language_model, so
        # `self` here is the enclosing multimodal model instance.
        model = getattr_iter(self.model, ("language_model", "text_model"), None)

    return LanguageModelWrapper(self)
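
A rough usage sketch, assuming mm_model is an already-initialized TransformersForMultimodalLM instance and both classes are in scope (names are illustrative):

lm = mm_model.get_language_model()
# The wrapper shares the multimodal model's state (same __dict__), so no
# weights are copied; it merely behaves like a TransformersForCausalLM.
assert isinstance(lm, TransformersForCausalLM)
# `lm.model` resolves to the inner language_model/text_model submodule picked by
# getattr_iter above, so callers expecting a causal-LM-shaped object work unchanged.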

get_multimodal_embeddings

get_multimodal_embeddings(**kwargs)
Source code in vllm/model_executor/models/transformers.py
def get_multimodal_embeddings(self, **kwargs):
    pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
    image_embeds: Optional[torch.Tensor] = kwargs.pop("image_embeds", None)
    # Model might use `image_patches` instead of `pixel_values`
    if pixel_values is None:
        pixel_values = kwargs.pop("image_patches", None)

    if image_embeds is not None:
        return image_embeds

    if pixel_values is None:
        return None

    num_image_patches = kwargs.pop("num_image_patches")
    kwargs.pop("token_type_ids", None)  # used only in `forward`
    if pixel_values is not None:
        vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)

        if isinstance(vision_embeddings, torch.Tensor):
            if vision_embeddings.ndim == 2:
                vision_embeddings = vision_embeddings.unsqueeze(0)

            # Embeddings have to be 2D tensors of length `num_images`,
            # but Transformers returns a concatenated tensor if the patches
            # have different sizes. We split it back to make vLLM happy.
            vision_embeddings = torch.split(
                vision_embeddings, num_image_patches.flatten().tolist()
            )
            vision_embeddings = [
                embed.flatten(start_dim=0, end_dim=-2)
                for embed in vision_embeddings
            ]

        return vision_embeddings
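
The split-and-flatten step can be exercised on its own with dummy tensors (shapes are illustrative): two images whose patch features were concatenated along dim 0 by the HF model, 3 patches for the first image and 1 for the second, each patch contributing 16 tokens of hidden size 8.

import torch

vision_embeddings = torch.randn(4, 16, 8)   # concatenated patch features
num_image_patches = torch.tensor([3, 1])    # patches per image

chunks = torch.split(vision_embeddings, num_image_patches.flatten().tolist())
per_image = [c.flatten(start_dim=0, end_dim=-2) for c in chunks]
print([tuple(e.shape) for e in per_image])  # [(48, 8), (16, 8)]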

can_enable_torch_compile

can_enable_torch_compile(vllm_config: VllmConfig) -> bool

Callable to be passed to @support_torch_compile's enable_if argument.

Defaults to True but is disabled in the following situations:

  • The model uses dynamic rope scaling.
Source code in vllm/model_executor/models/transformers.py
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """
    Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

    Defaults to `True` but is disabled in the following situations:

    - The model uses dynamic rope scaling.
    """
    enable = True
    text_config = vllm_config.model_config.hf_config.get_text_config()
    # Dynamic rope scaling is not compatible with torch.compile
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    if rope_scaling.get("rope_type") == "dynamic":
        enable = False
    return enable
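
A minimal sketch of the check, using hypothetical stand-ins for the nested config objects rather than a real VllmConfig, and assuming can_enable_torch_compile is imported from this module:

from types import SimpleNamespace

class _FakeHFConfig:
    """Hypothetical stand-in mimicking hf_config.get_text_config()."""

    def __init__(self, rope_scaling):
        self.rope_scaling = rope_scaling

    def get_text_config(self):
        return self

def _fake_vllm_config(rope_scaling):
    return SimpleNamespace(
        model_config=SimpleNamespace(hf_config=_FakeHFConfig(rope_scaling))
    )

print(can_enable_torch_compile(_fake_vllm_config(None)))                      # True
print(can_enable_torch_compile(_fake_vllm_config({"rope_type": "dynamic"})))  # False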

get_feature_request_tip

get_feature_request_tip(
    model: str, trust_remote_code: bool
) -> str
Source code in vllm/model_executor/models/transformers.py
def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
    gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
    url = hf_url if trust_remote_code else gh_url
    prefix = f"Please open {url} to request support for this feature. "
    if Path(model).exists():
        prefix = ""
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
    tip = f"See {doc_url} for instructions on how to add support yourself."
    return f"{prefix}{tip}"
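
For a Hub model ID (so Path(model).exists() is False), the returned tip depends only on trust_remote_code. The model name below is a placeholder:

tip = get_feature_request_tip("org/some-model", trust_remote_code=True)
# "Please open a discussion at https://huggingface.co/org/some-model/discussions/new
#  to request support for this feature. See https://docs.vllm.ai/en/latest/models/
#  supported_models.html#writing-custom-models for instructions on how to add
#  support yourself."
# With trust_remote_code=False the first sentence points at the transformers
# issue tracker instead; for a local path, the first sentence is omitted.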

init_on_device_without_buffers

init_on_device_without_buffers(device: device)

A context manager under which models are initialized with all parameters on the specified device. Buffers, however, are not initialized on the specified device.

Parameters:

Name      Type              Description                                Default
device    `torch.device`    Device to initialize all parameters on.    required
Source code in vllm/model_executor/models/transformers.py
@contextmanager
def init_on_device_without_buffers(device: torch.device):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However, buffers are not
    initialized on the specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
    """

    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(
                module._parameters[name].to(device), **kwargs
            )

    tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        for torch_function_name in tensor_constructors_to_patch:
            setattr(
                torch,
                torch_function_name,
                patch_tensor_constructor(getattr(torch, torch_function_name)),
            )
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        for (
            torch_function_name,
            old_torch_function,
        ) in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
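
A short usage sketch, assuming init_on_device_without_buffers is imported from this module: parameters created inside the context land on the target device, while buffers keep their default placement.

import torch
from torch import nn

with init_on_device_without_buffers(torch.device("meta")):
    layer = nn.Linear(8, 8)     # parameters only
    bn = nn.BatchNorm1d(8)      # parameters and buffers

print(layer.weight.device)      # meta
print(bn.running_mean.device)   # cpu (buffers are left untouched)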

log_replacement

log_replacement(
    name: str, old_module: Module, new_module: Module
)
Source code in vllm/model_executor/models/transformers.py
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    logger.debug("%s: %s -> %s", name, old_module, new_module)

replace_linear_class

replace_linear_class(
    linear: Linear,
    style: Style = "replicate",
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> Union[
    ColumnParallelLinear,
    RowParallelLinear,
    ReplicatedLinear,
]

Replace nn.Linear with one of vLLM's tensor parallel linear classes.

Parameters:

Name          Type                            Description                                                  Default
linear        Linear                          nn.Linear to be replaced.                                    required
style         Style                           Tensor parallel style of the new linear, e.g. "colwise".    'replicate'
quant_config  Optional[QuantizationConfig]    Quantization config for the new linear.                      None

Returns: The new linear.

Source code in vllm/model_executor/models/transformers.py
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear, e.g. "colwise".
        quant_config: Quantization config for the new linear.
    Returns:
        The new linear.
    """

    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls, vllm_linear_kwargs = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )
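
A hedged usage sketch: only the "replicate" style is shown because the column/row-parallel variants additionally require vLLM's tensor-parallel process group to be initialized. The prefix is illustrative.

from torch import nn

old = nn.Linear(1024, 4096, bias=False)
new = replace_linear_class(old, style="replicate", quant_config=None, prefix="mlp.up_proj")
# `new` is a ReplicatedLinear with input_size=1024, output_size=4096 and no bias.
# Only the class is swapped; the actual weights are loaded separately by the
# model's weight loader.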

replace_rms_norm_class

replace_rms_norm_class(
    rms_norm: Module, hidden_size: int
) -> RMSNorm

Replace a Transformers RMSNorm with vLLM's RMSNorm.

This method assumes:

  • Weight is stored as weight.
  • Epsilon is stored as eps or variance_epsilon.
  • with_scale indicates whether the layer has a weight (Gemma3n only).
  • var_hidden_size is only ever used for the Intern vision encoder in vLLM and Transformers doesn't appear to have the same concept.

Source code in vllm/model_executor/models/transformers.py
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Replace a Transformers RMSNorm with vLLM's RMSNorm.

    This method assumes:
    - Weight is stored as `weight`.
    - Epsilon is stored as `eps` or `variance_epsilon`.
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for the Intern vision encoder in vLLM
    and Transformers doesn't appear to have the same concept.
    """
    kwargs = {
        "hidden_size": hidden_size,
        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
        "has_weight": getattr(rms_norm, "with_scale", True),
    }
    if (weight := getattr(rms_norm, "weight", None)) is not None:
        # If weight is a Parameter, get its data tensor
        weight = getattr(weight, "data", weight)
        kwargs["dtype"] = weight.dtype
    else:
        # No weight, fall back to weightless RMSNorm
        kwargs["has_weight"] = False
    return RMSNorm(**kwargs)
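
The assumptions above can be exercised with a minimal Transformers-style RMSNorm stand-in (a hypothetical class, for illustration only):

import torch
from torch import nn

class HFStyleRMSNorm(nn.Module):
    """Hypothetical layer mirroring the usual Transformers RMSNorm attributes."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

vllm_norm = replace_rms_norm_class(HFStyleRMSNorm(4096), hidden_size=4096)
# -> vLLM RMSNorm with hidden_size=4096, eps=1e-6, has_weight=True and
#    dtype=torch.float32; as with the linear replacement, weight values are
#    loaded later by the weight loader.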

vllm_flash_attention_forward

vllm_flash_attention_forward(
    module: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Tensor,
    scaling: Optional[float] = None,
    attention_instances: Optional[dict[Attention]] = None,
    **kwargs,
)
Source code in vllm/model_executor/models/transformers.py
def vllm_flash_attention_forward(
    # Transformers args
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    # Transformers kwargs
    scaling: Optional[float] = None,
    # vLLM kwargs
    attention_instances: Optional[dict[Attention]] = None,
    **kwargs,
):
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)
    hidden = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
    return self_attn.forward(query, key, value), None
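
The final reshape converts Transformers' (batch, num_heads, seq_len, head_dim) layout into the flattened (num_tokens, num_heads * head_dim) layout expected by vLLM's Attention.forward. A shape-only illustration with made-up sizes (batch is 1 because vLLM passes tokens as a single flattened sequence):

import torch

query = torch.randn(1, 8, 5, 64)               # (batch, heads, tokens, head_dim)
hidden = query.shape[-2]                       # 5 tokens
q = query.transpose(1, 2).reshape(hidden, -1)  # (tokens, heads * head_dim)
print(q.shape)                                 # torch.Size([5, 512])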