vllm.v1.worker.lora_model_runner_mixin

Define LoRA functionality mixin for model runners.

InputBatch module-attribute

InputBatch = Union[GPUInputBatch, TPUInputBatch]

logger module-attribute

logger = init_logger(__name__)

LoRAModelRunnerMixin

Source code in vllm/v1/worker/lora_model_runner_mixin.py
class LoRAModelRunnerMixin:
    LORA_WARMUP_RANK = 8

    def load_lora_model(
        self, model: nn.Module, vllm_config: VllmConfig, device: torch.device
    ) -> nn.Module:
        if not supports_lora(model):
            raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")

        if supports_multimodal(model):
            logger.warning(
                "Regarding multimodal models, vLLM currently "
                "only supports adding LoRA to language model."
            )

        # Add LoRA Manager to the Model Runner
        self.lora_manager = LRUCacheWorkerLoRAManager(
            vllm_config,
            device,
            model.embedding_modules,
            model.embedding_padding_modules,
        )
        return self.lora_manager.create_lora_manager(model)

    def _set_active_loras(
        self,
        prompt_lora_mapping: tuple[int, ...],
        token_lora_mapping: tuple[int, ...],
        lora_requests: set[LoRARequest],
    ) -> None:
        self._ensure_lora_enabled()

        # Set is_prefill to True, so we always use the SGMV kernels on
        # non-cuda platforms.
        # On cuda platforms we use the same kernels for prefill and
        # decode and this flag is generally ignored.
        lora_mapping = LoRAMapping(
            token_lora_mapping, prompt_lora_mapping, is_prefill=True
        )
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

    def _ensure_lora_enabled(self) -> None:
        if not hasattr(self, "lora_manager"):
            raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")

    def set_active_loras(
        self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray
    ) -> None:
        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
        token_lora_mapping: tuple[int, ...]  # of size np.sum(num_scheduled_tokens)
        lora_requests: set[LoRARequest]
        prompt_lora_mapping, token_lora_mapping, lora_requests = (
            input_batch.make_lora_inputs(num_scheduled_tokens)
        )
        return self._set_active_loras(
            prompt_lora_mapping, token_lora_mapping, lora_requests
        )

    @contextmanager
    def maybe_setup_dummy_loras(
        self, lora_config: Optional[LoRAConfig], remove_lora: bool = True
    ):
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"

            num_loras = lora_config.max_loras

            # Make dummy lora requests
            lora_requests: set[LoRARequest] = {
                LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_path="/not/a/real/path",
                )
                for lora_id in range(1, num_loras + 1)
            }

            with self.lora_manager.dummy_lora_cache():
                # Add the dummy LoRAs here so _set_active_loras doesn't try to
                # load from disk.
                for lr in lora_requests:
                    self.lora_manager.add_dummy_lora(lr, rank=self.LORA_WARMUP_RANK)

                yield

            # __exit__ code
            if remove_lora:
                self.lora_manager.remove_all_adapters()

    @contextmanager
    def maybe_select_dummy_loras(
        self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray
    ):
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"

            num_reqs = len(num_scheduled_tokens)
            num_loras = lora_config.max_loras

            # Make prompt lora mapping
            # Assign LoRA IDs cyclically to simulate a worst-case scenario.
            prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1

            # Make token lora mapping
            token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)

            # Make dummy lora requests
            lora_requests: set[LoRARequest] = {
                LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_path="/not/a/real/path",
                )
                for lora_id in range(1, num_loras + 1)
            }

            self._set_active_loras(
                tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests
            )

            yield

    @contextmanager
    def maybe_dummy_run_with_lora(
        self,
        lora_config: Optional[LoRAConfig],
        num_scheduled_tokens: np.ndarray,
        remove_lora: bool = True,
    ):
        with (
            self.maybe_setup_dummy_loras(lora_config, remove_lora),
            self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens),
        ):
            yield

    def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]):
        if lora_config is None:
            return
        self.lora_manager.remove_all_adapters()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        self._ensure_lora_enabled()
        return self.lora_manager.add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        self._ensure_lora_enabled()
        return self.lora_manager.remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        self._ensure_lora_enabled()
        return self.lora_manager.pin_adapter(lora_id)

    def list_loras(self) -> set[int]:
        self._ensure_lora_enabled()
        return self.lora_manager.list_adapters()
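
Example: a minimal sketch, not taken from the vLLM source, of how a hypothetical model runner might inherit the mixin; MyModelRunner and its attributes are assumptions for illustration.

import torch
from torch import nn

from vllm.config import VllmConfig
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin


class MyModelRunner(LoRAModelRunnerMixin):
    """Hypothetical runner; only the LoRA-related wiring is shown."""

    def __init__(
        self, model: nn.Module, vllm_config: VllmConfig, device: torch.device
    ):
        self.vllm_config = vllm_config
        self.device = device
        if vllm_config.lora_config is not None:
            # Wraps the model with LoRA support and attaches
            # self.lora_manager (an LRUCacheWorkerLoRAManager).
            model = self.load_lora_model(model, vllm_config, device)
        self.model = model

Per step, such a runner would call set_active_loras(input_batch, num_scheduled_tokens) before the forward pass, as documented below.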

LORA_WARMUP_RANK class-attribute instance-attribute

LORA_WARMUP_RANK = 8

_ensure_lora_enabled

_ensure_lora_enabled() -> None
Source code in vllm/v1/worker/lora_model_runner_mixin.py
def _ensure_lora_enabled(self) -> None:
    if not hasattr(self, "lora_manager"):
        raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")

_set_active_loras

_set_active_loras(
    prompt_lora_mapping: tuple[int, ...],
    token_lora_mapping: tuple[int, ...],
    lora_requests: set[LoRARequest],
) -> None
Source code in vllm/v1/worker/lora_model_runner_mixin.py
def _set_active_loras(
    self,
    prompt_lora_mapping: tuple[int, ...],
    token_lora_mapping: tuple[int, ...],
    lora_requests: set[LoRARequest],
) -> None:
    self._ensure_lora_enabled()

    # Set is_prefill to True, so we always use the SGMV kernels on
    # non-cuda platforms.
    # On cuda platforms we use the same kernels for prefill and
    # decode and this flag is generally ignored.
    lora_mapping = LoRAMapping(
        token_lora_mapping, prompt_lora_mapping, is_prefill=True
    )
    self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
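
Example: a minimal sketch of calling this method directly with hand-built mappings, mirroring what maybe_select_dummy_loras does during warmup; runner and the adapter path are assumptions for illustration.

from vllm.lora.request import LoRARequest

# Two requests: request 0 uses adapter id 1, request 1 runs without LoRA (id 0).
prompt_lora_mapping = (1, 0)
# Request 0 has 3 scheduled tokens, request 1 has 2.
token_lora_mapping = (1, 1, 1, 0, 0)
lora_requests = {
    LoRARequest(
        lora_name="adapter_1",
        lora_int_id=1,
        lora_path="/path/to/adapter_1",  # placeholder; must point to a real adapter
    ),
}
runner._set_active_loras(prompt_lora_mapping, token_lora_mapping, lora_requests)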

add_lora

add_lora(lora_request: LoRARequest) -> bool
Source code in vllm/v1/worker/lora_model_runner_mixin.py
def add_lora(self, lora_request: LoRARequest) -> bool:
    self._ensure_lora_enabled()
    return self.lora_manager.add_adapter(lora_request)
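
Example: a minimal sketch of managing adapters through the mixin at runtime; runner and the adapter path are assumptions for illustration.

from vllm.lora.request import LoRARequest

req = LoRARequest(
    lora_name="sql_adapter",
    lora_int_id=1,
    lora_path="/path/to/sql_adapter",  # placeholder for a real adapter directory
)
runner.add_lora(req)        # True if the adapter was registered
print(runner.list_loras())  # e.g. {1}
runner.pin_lora(1)          # keep adapter 1 resident in the LRU cache
runner.remove_lora(1)       # True if the adapter was removed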

list_loras

list_loras() -> set[int]
Source code in vllm/v1/worker/lora_model_runner_mixin.py
def list_loras(self) -> set[int]:
    self._ensure_lora_enabled()
    return self.lora_manager.list_adapters()

load_lora_model

load_lora_model(
    model: Module, vllm_config: VllmConfig, device: device
) -> Module
Source code in vllm/v1/worker/lora_model_runner_mixin.py
def load_lora_model(
    self, model: nn.Module, vllm_config: VllmConfig, device: torch.device
) -> nn.Module:
    if not supports_lora(model):
        raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")

    if supports_multimodal(model):
        logger.warning(
            "Regarding multimodal models, vLLM currently "
            "only supports adding LoRA to language model."
        )

    # Add LoRA Manager to the Model Runner
    self.lora_manager = LRUCacheWorkerLoRAManager(
        vllm_config,
        device,
        model.embedding_modules,
        model.embedding_padding_modules,
    )
    return self.lora_manager.create_lora_manager(model)

maybe_dummy_run_with_lora

maybe_dummy_run_with_lora(
    lora_config: Optional[LoRAConfig],
    num_scheduled_tokens: ndarray,
    remove_lora: bool = True,
)
Source code in vllm/v1/worker/lora_model_runner_mixin.py
@contextmanager
def maybe_dummy_run_with_lora(
    self,
    lora_config: Optional[LoRAConfig],
    num_scheduled_tokens: np.ndarray,
    remove_lora: bool = True,
):
    with (
        self.maybe_setup_dummy_loras(lora_config, remove_lora),
        self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens),
    ):
        yield
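
Example: a minimal sketch of wrapping a warmup or profiling run, assuming runner mixes in LoRAModelRunnerMixin and exposes the engine's VllmConfig; the token counts are illustrative.

import numpy as np

num_scheduled_tokens = np.array([8, 8, 16], dtype=np.int32)  # 3 dummy requests

with runner.maybe_dummy_run_with_lora(
    runner.vllm_config.lora_config,  # None turns both context managers into no-ops
    num_scheduled_tokens,
    remove_lora=True,
):
    # Run the dummy forward pass here: dummy adapters of rank LORA_WARMUP_RANK
    # are registered and cyclically selected for its duration.
    ...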

maybe_remove_all_loras

maybe_remove_all_loras(lora_config: Optional[LoRAConfig])
Source code in vllm/v1/worker/lora_model_runner_mixin.py
def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]):
    if lora_config is None:
        return
    self.lora_manager.remove_all_adapters()

maybe_select_dummy_loras

maybe_select_dummy_loras(
    lora_config: Optional[LoRAConfig],
    num_scheduled_tokens: ndarray,
)
Source code in vllm/v1/worker/lora_model_runner_mixin.py
@contextmanager
def maybe_select_dummy_loras(
    self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray
):
    if lora_config is None:
        yield
    else:
        # __enter__ code
        assert self.lora_manager is not None, "LoRA is not enabled"

        num_reqs = len(num_scheduled_tokens)
        num_loras = lora_config.max_loras

        # Make prompt lora mapping
        # Assign LoRA IDs cyclically to simulate a worst-case scenario.
        prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1

        # Make token lora mapping
        token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)

        # Make dummy lora requests
        lora_requests: set[LoRARequest] = {
            LoRARequest(
                lora_name=f"warmup_{lora_id}",
                lora_int_id=lora_id,
                lora_path="/not/a/real/path",
            )
            for lora_id in range(1, num_loras + 1)
        }

        self._set_active_loras(
            tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests
        )

        yield
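
The cyclic assignment above spreads all max_loras adapter ids across the batch to exercise a worst-case mix. A self-contained illustration of the mappings it builds (the values are made up):

import numpy as np

num_scheduled_tokens = np.array([2, 1, 2, 1, 2])  # 5 dummy requests
num_loras = 2                                     # lora_config.max_loras
num_reqs = len(num_scheduled_tokens)

prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)

print(prompt_lora_mapping.tolist())  # [1, 2, 1, 2, 1]
print(token_lora_mapping.tolist())   # [1, 1, 2, 1, 1, 2, 1, 1]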

maybe_setup_dummy_loras

maybe_setup_dummy_loras(
    lora_config: Optional[LoRAConfig],
    remove_lora: bool = True,
)
Source code in vllm/v1/worker/lora_model_runner_mixin.py
@contextmanager
def maybe_setup_dummy_loras(
    self, lora_config: Optional[LoRAConfig], remove_lora: bool = True
):
    if lora_config is None:
        yield
    else:
        # __enter__ code
        assert self.lora_manager is not None, "LoRA is not enabled"

        num_loras = lora_config.max_loras

        # Make dummy lora requests
        lora_requests: set[LoRARequest] = {
            LoRARequest(
                lora_name=f"warmup_{lora_id}",
                lora_int_id=lora_id,
                lora_path="/not/a/real/path",
            )
            for lora_id in range(1, num_loras + 1)
        }

        with self.lora_manager.dummy_lora_cache():
            # Add the dummy LoRAs here so _set_active_loras doesn't try to
            # load from disk.
            for lr in lora_requests:
                self.lora_manager.add_dummy_lora(lr, rank=self.LORA_WARMUP_RANK)

            yield

        # __exit__ code
        if remove_lora:
            self.lora_manager.remove_all_adapters()
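
Example: a minimal sketch of using this context manager on its own, assuming runner and lora_config come from an initialized worker.

with runner.maybe_setup_dummy_loras(lora_config, remove_lora=True):
    # lora_config.max_loras dummy adapters (rank LORA_WARMUP_RANK) are now
    # registered, so _set_active_loras can select them without touching disk.
    ...
# With remove_lora=True, all adapters are cleared again on exit.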

pin_lora

pin_lora(lora_id: int) -> bool
Source code in vllm/v1/worker/lora_model_runner_mixin.py
def pin_lora(self, lora_id: int) -> bool:
    self._ensure_lora_enabled()
    return self.lora_manager.pin_adapter(lora_id)

remove_lora

remove_lora(lora_id: int) -> bool
Source code in vllm/v1/worker/lora_model_runner_mixin.py
def remove_lora(self, lora_id: int) -> bool:
    self._ensure_lora_enabled()
    return self.lora_manager.remove_adapter(lora_id)

set_active_loras

set_active_loras(
    input_batch: InputBatch, num_scheduled_tokens: ndarray
) -> None
Source code in vllm/v1/worker/lora_model_runner_mixin.py
def set_active_loras(
    self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray
) -> None:
    prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
    token_lora_mapping: tuple[int, ...]  # of size np.sum(num_scheduled_tokens)
    lora_requests: set[LoRARequest]
    prompt_lora_mapping, token_lora_mapping, lora_requests = (
        input_batch.make_lora_inputs(num_scheduled_tokens)
    )
    return self._set_active_loras(
        prompt_lora_mapping, token_lora_mapping, lora_requests
    )
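
The comments above pin down the mapping shapes: one LoRA id per request and one per scheduled token. A small self-contained illustration of that relationship, mirroring the np.repeat construction used by maybe_select_dummy_loras (the values are made up):

import numpy as np

num_scheduled_tokens = np.array([3, 1, 2])  # tokens scheduled per request
prompt_lora_mapping = (2, 0, 1)             # one LoRA id per request (0 = no LoRA)
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
print(token_lora_mapping.tolist())  # [2, 2, 2, 0, 1, 1]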