vllm.attention.ops.flashmla ¶
_is_flashmla_available ¶
Source code in vllm/attention/ops/flashmla.py
flash_mla_sparse_prefill ¶
flash_mla_sparse_prefill(
q: Tensor,
kv: Tensor,
indices: Tensor,
sm_scale: float,
d_v: int = 512,
) -> tuple[Tensor, Tensor, Tensor]
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or to numbers >= s_kv
- sm_scale: float
- d_v: The dimension of the value vectors. Can only be 512

Returns:
- (output, max_logits, lse). For the definitions of output, max_logits, and lse, please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits: [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
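A minimal usage sketch (not taken from the source); the shapes, scale, and device below are illustrative assumptions and require a GPU build where the FlashMLA sparse kernels are available:

```python
import torch

from vllm.attention.ops.flashmla import flash_mla_sparse_prefill

# Assumed MLA-style shapes; real values depend on the model.
s_q, s_kv, h_q, h_kv, d_qk, d_v, topk = 4, 1024, 16, 1, 576, 512, 128

q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(s_kv, h_kv, d_qk, dtype=torch.bfloat16, device="cuda")

# Top-k KV indices per query; pad unused slots with -1 so they are ignored.
indices = torch.randint(0, s_kv, (s_q, h_kv, topk), dtype=torch.int32, device="cuda")

output, max_logits, lse = flash_mla_sparse_prefill(
    q, kv, indices, sm_scale=d_qk ** -0.5, d_v=d_v
)
# output: [s_q, h_q, d_v] bfloat16; max_logits, lse: [s_q, h_q] float
```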
Source code in vllm/attention/ops/flashmla.py
flash_mla_with_kvcache ¶
flash_mla_with_kvcache(
q: Tensor,
k_cache: Tensor,
block_table: Tensor,
cache_seqlens: Tensor,
head_dim_v: int,
tile_scheduler_metadata: Tensor,
num_splits: Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[Tensor] = None,
descale_k: Optional[Tensor] = None,
is_fp8_kvcache: bool = False,
indices: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor]
Arguments:
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head dimension of v.
- tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
- num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
- softmax_scale: float. The scale of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
- causal: bool. Whether to apply a causal attention mask.
- descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the FP8 KV cache format, please refer to README.md.
- indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention is enabled and only tokens in the indices array are attended to. Invalid indices should be set to -1 or to numbers >= total_seq_len_kv. For details about how to set up indices, please refer to README.md.

Returns:
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
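A minimal decode-style sketch (illustrative, not from the source); the batch, head, and block sizes are assumptions, and a FlashMLA-capable GPU is required:

```python
import torch

from vllm.attention.ops.flashmla import (
    flash_mla_with_kvcache,
    get_mla_metadata,
)

# Assumed decode-style shapes; real values come from the model and KV-cache layout.
batch_size, seq_len_q, num_heads_q, num_heads_k = 2, 1, 128, 1
head_dim, head_dim_v, page_block_size, num_blocks = 576, 512, 64, 64

q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_dim,
                      dtype=torch.bfloat16, device="cuda")

max_num_blocks_per_seq = num_blocks // batch_size
block_table = torch.arange(batch_size * max_num_blocks_per_seq, dtype=torch.int32,
                           device="cuda").view(batch_size, max_num_blocks_per_seq)
cache_seqlens = torch.full((batch_size,), 1024, dtype=torch.int32, device="cuda")

# Scheduler metadata is produced once per batch by get_mla_metadata.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    seq_len_q * num_heads_q // num_heads_k,
    num_heads_k,
)

out, softmax_lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, head_dim_v,
    tile_scheduler_metadata, num_splits,
    causal=True,
)
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v)
# softmax_lse: (batch_size, num_heads_q, seq_len_q), float32
```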
Source code in vllm/attention/ops/flashmla.py
get_mla_metadata ¶
get_mla_metadata(
cache_seqlens: Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None,
) -> tuple[Tensor, Tensor]
Arguments:
- cache_seqlens: (batch_size), dtype torch.int32.
- num_q_tokens_per_head_k: Equals num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
- num_heads_k: The number of k heads.
- num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled.
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
- topk: If not None, sparse attention will be enabled, and only tokens in the indices array passed to flash_mla_with_kvcache_sm90 will be attended to.

Returns:
- tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
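A sketch of the sparse-attention case (sizes are assumptions); when topk is set, num_heads_q must also be passed:

```python
import torch

from vllm.attention.ops.flashmla import get_mla_metadata

# Assumed sizes for a decode step with top-k sparse attention.
batch_size, seq_len_q = 4, 1
num_heads_q, num_heads_k, topk = 128, 1, 2048

cache_seqlens = torch.full((batch_size,), 8192, dtype=torch.int32, device="cuda")

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    seq_len_q * num_heads_q // num_heads_k,  # num_q_tokens_per_head_k
    num_heads_k,
    num_heads_q=num_heads_q,
    topk=topk,
)
# tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), int32
# num_splits: (batch_size + 1), int32
```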
Source code in vllm/attention/ops/flashmla.py
is_flashmla_dense_supported ¶
Returns: is_supported_flag, unsupported_reason (optional).
Source code in vllm/attention/ops/flashmla.py
is_flashmla_sparse_supported ¶
Returns: is_supported_flag, unsupported_reason (optional).
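A short sketch of guarding FlashMLA usage behind these checks; it assumes both helpers take no arguments and return a (flag, reason) pair as documented above:

```python
from vllm.attention.ops.flashmla import (
    is_flashmla_dense_supported,
    is_flashmla_sparse_supported,
)

# Check platform support before selecting a FlashMLA code path.
dense_ok, dense_reason = is_flashmla_dense_supported()
sparse_ok, sparse_reason = is_flashmla_sparse_supported()

if not dense_ok:
    print(f"FlashMLA dense kernels unavailable: {dense_reason}")
if not sparse_ok:
    print(f"FlashMLA sparse kernels unavailable: {sparse_reason}")
```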