vllm.attention.layers.encoder_only_attention

EncoderOnlyAttention

Bases: Attention

Encoder attention is a special case that doesn't need a KV Cache.

Source code in vllm/attention/layers/encoder_only_attention.py
class EncoderOnlyAttention(Attention):
    """
    Encoder attention is a special case that doesn't need a KV Cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: Optional[CacheConfig] = None,
        attn_type: Optional[str] = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )

            attn_backend = create_encoder_only_attention_backend(
                underlying_attn_backend
            )
        else:
            # In V0, encoder-only attention is handled inside the backends
            attn_backend = None

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, (
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
            )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            attn_type=AttentionType.ENCODER_ONLY,
            **kwargs,
        )
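
The layer is built like any other vLLM attention module; extra keyword arguments are forwarded unchanged to the parent Attention class. A minimal sketch of how a BERT-style encoder layer might construct it (illustrative only: the prefix keyword is assumed to be accepted via **kwargs, and in practice this runs during model initialization, where a vLLM config context is active):

from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention

hidden_size = 768
num_heads = 12
head_size = hidden_size // num_heads

# Sketch, not the canonical construction path: no KV cache is allocated for
# this layer, and cache_config=None falls back to kv_cache_dtype="auto".
attn = EncoderOnlyAttention(
    num_heads=num_heads,
    head_size=head_size,
    scale=head_size**-0.5,
    cache_config=None,
    prefix="encoder.layers.0.self_attn",  # assumed **kwargs passthrough
)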

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: Optional[CacheConfig] = None,
    attn_type: Optional[str] = None,
    **kwargs,
)
Source code in vllm/attention/layers/encoder_only_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: Optional[CacheConfig] = None,
    attn_type: Optional[str] = None,
    **kwargs,
):
    dtype = torch.get_default_dtype()

    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
    else:
        kv_cache_dtype = "auto"
        block_size = 16

    if envs.VLLM_USE_V1:
        underlying_attn_backend = get_attn_backend(
            head_size, dtype, kv_cache_dtype, block_size
        )

        attn_backend = create_encoder_only_attention_backend(
            underlying_attn_backend
        )
    else:
        # In V0, encoder-only attention is handled inside the backends
        attn_backend = None

    if attn_type is not None:
        assert attn_type == AttentionType.ENCODER_ONLY, (
            "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
        )

    super().__init__(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        cache_config=cache_config,
        attn_backend=attn_backend,
        attn_type=AttentionType.ENCODER_ONLY,
        **kwargs,
    )
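
The attn_type parameter exists only so call sites can pass it uniformly; the layer always registers itself as AttentionType.ENCODER_ONLY, and any other value trips the assertion above. A small illustrative sketch of that guard (like the example above, real construction happens inside model initialization under an active vLLM config):

from vllm.attention import AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention

try:
    # Anything other than ENCODER_ONLY is rejected.
    EncoderOnlyAttention(
        num_heads=12,
        head_size=64,
        scale=64**-0.5,
        attn_type=AttentionType.DECODER,
    )
except AssertionError as err:
    print(err)  # EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY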

create_encoder_only_attention_backend (cached)

create_encoder_only_attention_backend(
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]
Source code in vllm/attention/layers/encoder_only_attention.py
@functools.lru_cache
def create_encoder_only_attention_backend(
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    prefix = "EncoderOnlyAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class EncoderOnlyAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            new_common_attn_metadata = copy(common_attn_metadata)
            new_common_attn_metadata.causal = False
            return super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=EncoderOnlyAttentionBuilder,
    )

    return attn_backend
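
The factory contains no attention math of its own: it takes whichever backend get_attn_backend selected, subclasses that backend's metadata builder so every batch is built with causal=False (an encoder attends bidirectionally over the full sequence), and functools.lru_cache guarantees that only one wrapper class is generated per underlying backend. A self-contained sketch of the same wrap-the-builder pattern with stand-in classes (none of the names below are vLLM APIs):

import functools
from copy import copy
from dataclasses import dataclass


@dataclass
class Metadata:
    causal: bool = True


class BaseBuilder:
    def build(self, metadata: Metadata) -> Metadata:
        return metadata


class BaseBackend:
    @staticmethod
    def get_builder_cls() -> type[BaseBuilder]:
        return BaseBuilder


@functools.lru_cache
def make_non_causal_backend(backend: type[BaseBackend]) -> type[BaseBuilder]:
    underlying_builder = backend.get_builder_cls()

    class NonCausalBuilder(underlying_builder):  # type: ignore
        def build(self, metadata: Metadata) -> Metadata:
            # Copy first so the caller's metadata object is left untouched,
            # then force non-causal (bidirectional) attention.
            new_metadata = copy(metadata)
            new_metadata.causal = False
            return super().build(new_metadata)

    return NonCausalBuilder


builder_cls = make_non_causal_backend(BaseBackend)
assert builder_cls is make_non_causal_backend(BaseBackend)  # cached: same class
assert builder_cls().build(Metadata()).causal is False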