
vllm.v1.worker.kv_connector_model_runner_mixin

Define a KV connector functionality mixin for model runners.

logger module-attribute

logger = init_logger(__name__)

KVConnectorModelRunnerMixin

Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
class KVConnectorModelRunnerMixin:
    @staticmethod
    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
        # Update the KVConnector with the KVConnector metadata before forward().
        if has_kv_transfer_group():
            kv_connector = get_kv_transfer_group()
            assert isinstance(kv_connector, KVConnectorBase)
            assert scheduler_output.kv_connector_metadata is not None
            kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)

            # Background KV cache transfers happen here.
            # These transfers are designed to be async and the requests
            # involved may be disjoint from the running requests.
            # Do this here to save a collective_rpc.
            kv_connector.start_load_kv(get_forward_context())

    @staticmethod
    def ensure_kv_transfer_shutdown() -> None:
        # has_kv_transfer_group can be None during interpreter shutdown.
        if has_kv_transfer_group and has_kv_transfer_group():
            ensure_kv_transfer_shutdown()

    @staticmethod
    def maybe_wait_for_kv_save() -> None:
        if has_kv_transfer_group():
            get_kv_transfer_group().wait_for_save()

    @staticmethod
    def get_finished_kv_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_finished(
                scheduler_output.finished_req_ids
            )
        return None, None

    @staticmethod
    def kv_connector_no_forward(
        scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
    ) -> ModelRunnerOutput:
        # Run KV send/recv even when there is no model work to do.
        with (
            set_forward_context(None, vllm_config),
            KVConnectorModelRunnerMixin._get_kv_connector_output(
                scheduler_output, wait_for_save=False
            ) as kv_connector_output,
        ):
            pass

        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    @staticmethod
    def maybe_get_kv_connector_output(
        scheduler_output: "SchedulerOutput",
    ) -> AbstractContextManager[Optional[KVConnectorOutput]]:
        return (
            KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
            if has_kv_transfer_group()
            else nullcontext()
        )

    # This context manager must be used within an active forward context.
    # It encapsulates the entire KV connector lifecycle within execute_model
    @staticmethod
    @contextmanager
    def _get_kv_connector_output(
        scheduler_output: "SchedulerOutput", wait_for_save: bool = True
    ) -> Generator[KVConnectorOutput, None, None]:
        output = KVConnectorOutput()

        # Update the KVConnector with the KVConnector metadata before forward().
        kv_connector = get_kv_transfer_group()
        assert isinstance(kv_connector, KVConnectorBase)
        assert scheduler_output.kv_connector_metadata is not None
        kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)

        # Background KV cache transfers happen here.
        # These transfers are designed to be async and the requests
        # involved may be disjoint from the running requests.
        # Do this here to save a collective_rpc.
        kv_connector.start_load_kv(get_forward_context())
        try:
            yield output
        finally:
            if wait_for_save:
                kv_connector.wait_for_save()

            output.finished_sending, output.finished_recving = (
                kv_connector.get_finished(scheduler_output.finished_req_ids)
            )
            output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()

            output.kv_connector_stats = (
                KVConnectorModelRunnerMixin.get_kv_connector_stats()
            )
            kv_connector.clear_connector_metadata()

    @staticmethod
    def get_kv_connector_stats() -> Optional[KVConnectorStats]:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_kv_connector_stats()
        return None
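
A model runner inherits this mixin alongside its usual base class. The sketch below is a minimal, hypothetical illustration of how the "manual" helpers could slot into an execute_model step: MyModelRunner, self.model, self.vllm_config and _build_output are illustrative names rather than vLLM API, and the import path for set_forward_context is an assumption.

# Hypothetical runner sketch; only the KV-connector call sites are real API.
from vllm.forward_context import set_forward_context  # assumed import path
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin)


class MyModelRunner(KVConnectorModelRunnerMixin):  # hypothetical class

    def execute_model(self, scheduler_output):
        # Attention metadata is elided from the sketch; a real runner would
        # pass its metadata instead of None.
        with set_forward_context(None, self.vllm_config):
            # Bind connector metadata and start async KV loads (a no-op when
            # no KV transfer group is configured).
            self.maybe_setup_kv_connector(scheduler_output)

            hidden_states = self.model(...)  # placeholder forward pass

            # Block until KV saves issued during the forward have completed.
            self.maybe_wait_for_kv_save()

            # Request ids whose KV send/recv finished (None without a connector).
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))

        # Hypothetical output construction.
        return self._build_output(hidden_states, finished_sending,
                                  finished_recving)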

_get_kv_connector_output staticmethod

_get_kv_connector_output(
    scheduler_output: SchedulerOutput,
    wait_for_save: bool = True,
) -> Generator[KVConnectorOutput, None, None]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
@contextmanager
def _get_kv_connector_output(
    scheduler_output: "SchedulerOutput", wait_for_save: bool = True
) -> Generator[KVConnectorOutput, None, None]:
    output = KVConnectorOutput()

    # Update the KVConnector with the KVConnector metadata before forward().
    kv_connector = get_kv_transfer_group()
    assert isinstance(kv_connector, KVConnectorBase)
    assert scheduler_output.kv_connector_metadata is not None
    kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)

    # Background KV cache transfers happen here.
    # These transfers are designed to be async and the requests
    # involved may be disjoint from the running requests.
    # Do this here to save a collective_rpc.
    kv_connector.start_load_kv(get_forward_context())
    try:
        yield output
    finally:
        if wait_for_save:
            kv_connector.wait_for_save()

        output.finished_sending, output.finished_recving = (
            kv_connector.get_finished(scheduler_output.finished_req_ids)
        )
        output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()

        output.kv_connector_stats = (
            KVConnectorModelRunnerMixin.get_kv_connector_stats()
        )
        kv_connector.clear_connector_metadata()
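
Note that the yielded KVConnectorOutput starts empty and is only populated in the finally block, i.e. once the with statement has exited. A small illustrative helper for inspecting those fields; the helper and its formatting are hypothetical, while the field names come from the source above.

def summarize_kv_output(out: KVConnectorOutput) -> str:  # hypothetical helper
    # All fields are filled on context-manager exit, so call this only after
    # the `with` block has completed.
    return (f"sent={sorted(out.finished_sending or ())} "
            f"recv={sorted(out.finished_recving or ())} "
            f"bad_blocks={sorted(out.invalid_block_ids or ())} "
            f"stats={out.kv_connector_stats}")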

ensure_kv_transfer_shutdown staticmethod

ensure_kv_transfer_shutdown() -> None
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def ensure_kv_transfer_shutdown() -> None:
    # has_kv_transfer_group can be None during interpreter shutdown.
    if has_kv_transfer_group and has_kv_transfer_group():
        ensure_kv_transfer_shutdown()
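
The truthiness check on has_kv_transfer_group itself guards against module globals having already been cleared during interpreter shutdown, and the inner ensure_kv_transfer_shutdown() call presumably resolves to a module-level helper imported by this file (imports are not shown on this page), not to the static method. A hypothetical teardown hook might look like this; Worker and its shutdown method are illustrative.

class Worker:  # hypothetical worker, not vLLM's
    def shutdown(self) -> None:
        # Safe to call even if no KV connector was ever configured.
        KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown()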

get_finished_kv_transfers staticmethod

get_finished_kv_transfers(
    scheduler_output: SchedulerOutput,
) -> tuple[Optional[set[str]], Optional[set[str]]]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def get_finished_kv_transfers(
    scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    if has_kv_transfer_group():
        return get_kv_transfer_group().get_finished(
            scheduler_output.finished_req_ids
        )
    return None, None
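
Both elements of the returned tuple are Optional: None means no KV transfer group is configured, while an empty set means a connector is present but nothing finished on this step. A hedged sketch of consuming the result; the print handling is purely illustrative.

finished_sending, finished_recving = (
    KVConnectorModelRunnerMixin.get_finished_kv_transfers(scheduler_output))

# None -> no connector configured; empty set -> connector configured, but no
# request finished transferring this step.
for req_id in finished_sending or ():
    print(f"KV cache for request {req_id} fully sent")
for req_id in finished_recving or ():
    print(f"KV cache for request {req_id} fully received")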

get_kv_connector_stats staticmethod

get_kv_connector_stats() -> Optional[KVConnectorStats]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def get_kv_connector_stats() -> Optional[KVConnectorStats]:
    if has_kv_transfer_group():
        return get_kv_transfer_group().get_kv_connector_stats()
    return None
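
A short sketch of surfacing the per-step stats when a connector is configured; the logger usage mirrors the module-level logger shown at the top of this page, and the debug message is illustrative.

stats = KVConnectorModelRunnerMixin.get_kv_connector_stats()
if stats is not None:
    # KVConnectorStats contents are connector-specific; just log them here.
    logger.debug("KV connector stats for this step: %s", stats)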

kv_connector_no_forward staticmethod

kv_connector_no_forward(
    scheduler_output: SchedulerOutput,
    vllm_config: VllmConfig,
) -> ModelRunnerOutput
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def kv_connector_no_forward(
    scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
) -> ModelRunnerOutput:
    # Run KV send/recv even when there is no model work to do.
    with (
        set_forward_context(None, vllm_config),
        KVConnectorModelRunnerMixin._get_kv_connector_output(
            scheduler_output, wait_for_save=False
        ) as kv_connector_output,
    ):
        pass

    if kv_connector_output.is_empty():
        return EMPTY_MODEL_RUNNER_OUTPUT

    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    output.kv_connector_output = kv_connector_output
    return output
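
A runner would typically reach for this helper on steps where the scheduler produced no work but a KV connector is configured. The guard below is a hypothetical sketch: the total_num_scheduled_tokens attribute and the has_kv_transfer_group import path are assumptions, and _RunnerSketch is illustrative.

from vllm.distributed.kv_transfer import has_kv_transfer_group  # assumed path


class _RunnerSketch(KVConnectorModelRunnerMixin):  # hypothetical runner

    def execute_model(self, scheduler_output):
        if scheduler_output.total_num_scheduled_tokens == 0:
            if not has_kv_transfer_group():
                return EMPTY_MODEL_RUNNER_OUTPUT  # nothing at all to do
            # No forward pass this step, but in-flight KV sends/receives
            # still make progress (and are reported via kv_connector_output).
            return self.kv_connector_no_forward(scheduler_output,
                                                self.vllm_config)
        ...  # normal forward path continues here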

maybe_get_kv_connector_output staticmethod

maybe_get_kv_connector_output(
    scheduler_output: SchedulerOutput,
) -> AbstractContextManager[Optional[KVConnectorOutput]]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def maybe_get_kv_connector_output(
    scheduler_output: "SchedulerOutput",
) -> AbstractContextManager[Optional[KVConnectorOutput]]:
    return (
        KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
        if has_kv_transfer_group()
        else nullcontext()
    )
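
Because of the nullcontext() fallback, callers can use a single with statement regardless of whether a connector is configured; the yielded value is simply None when there is none. A hedged sketch of a hypothetical execute_model body follows; self.model and make_model_runner_output are illustrative placeholders.

# Inside an active forward context, within a hypothetical execute_model body:
with self.maybe_get_kv_connector_output(
        scheduler_output) as kv_connector_output:
    hidden_states = self.model(...)  # placeholder forward pass

output = make_model_runner_output(hidden_states)  # hypothetical constructor
if kv_connector_output is not None:
    # Populated on context exit: finished_sending/recving, invalid_block_ids
    # and kv_connector_stats (see _get_kv_connector_output above).
    output.kv_connector_output = kv_connector_output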

maybe_setup_kv_connector staticmethod

maybe_setup_kv_connector(scheduler_output: SchedulerOutput)
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
    # Update the KVConnector with the KVConnector metadata before forward().
    if has_kv_transfer_group():
        kv_connector = get_kv_transfer_group()
        assert isinstance(kv_connector, KVConnectorBase)
        assert scheduler_output.kv_connector_metadata is not None
        kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)

        # Background KV cache transfers happen here.
        # These transfers are designed to be async and the requests
        # involved may be disjoint from the running requests.
        # Do this here to save a collective_rpc.
        kv_connector.start_load_kv(get_forward_context())

maybe_wait_for_kv_save staticmethod

maybe_wait_for_kv_save() -> None
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
@staticmethod
def maybe_wait_for_kv_save() -> None:
    if has_kv_transfer_group():
        get_kv_transfer_group().wait_for_save()
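
This is the counterpart to maybe_setup_kv_connector(): once the forward pass has run, it waits for the KV save operations issued during that forward to complete, and it is a no-op when no KV transfer group is configured. A minimal sketch of where it sits in a step, assuming the setup call shown above has already run and model/input_ids are placeholders:

hidden_states = model(input_ids)  # placeholder forward pass

# Wait for KV saves issued during the forward to finish before the step ends.
KVConnectorModelRunnerMixin.maybe_wait_for_kv_save()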