deepseek-ai/FlashMLA
GitHub: deepseek-ai/FlashMLA
FlashMLA 是 DeepSeek 开源的高效多头潜在注意力核心库,为 DeepSeek-V3 系列模型提供推理加速支持。
Stars: 12513 | Forks: 997
# FlashMLA
## 介绍
FlashMLA 是 DeepSeek 的优化 attention kernel 库,为 [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) 和 [DeepSeek-V3.2-Exp](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) 模型提供支持。此代码库包含以下实现:
**Sparse Attention Kernels**
*这些 kernel 支持 DeepSeek Sparse Attention (DSA),如[本文](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp)所述。*
- 用于 prefill 阶段的 Token-level sparse attention
- 用于 decoding 阶段的 Token-level sparse attention,支持 FP8 KV cache
**Dense Attention Kernels**
- 用于 prefill 阶段的 Dense attention
- 用于 decoding 阶段的 Dense attention
## 新闻
- **2025.09.29 发布 Sparse Attention Kernels**:随着 [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) 的发布,我们发布了相应的 token-level sparse attention kernel。这些 kernel 支持模型的 DeepSeek Sparse Attention (DSA),在 prefill 期间可达到 640 TFlops,在 decoding 期间可达到 410 TFlops。我们还发布了一篇关于我们新 FP8 sparse decoding kernel 的深度解析博客。点击[这里](docs/20250929-hopper-fp8-sparse-deep-dive.md)查看。
- **2025.08.01 SM100 上的 MHA Kernels**:感谢 NVIDIA 提交的关于 SM100 上 MHA forward / backward kernel 的 [PR](https://github.com/deepseek-ai/FlashMLA/pull/76)!
- **2025.04.22 深度解析博客**:我们非常乐意分享新 FlashMLA kernel 背后的技术细节!请在此处查看我们的深度解析文章[点击查看](docs/20250422-new-kernel-deep-dive.md)。
- **2025.04.22 性能更新**:我们很高兴地宣布 Flash MLA 的新版本发布,该版本在计算密集型工作负载上实现了 5% ~ 15% 的性能提升,在 NVIDIA H800 SXM5 GPU 上达到了 660 TFlops。新版本的接口与旧版本完全兼容。只需升级到新版本即可立即获得性能提升!🚀🚀🚀
## 性能
#### 测试 & 基准测试 MLA decoding (Sparse & Dense):
```
python tests/test_flash_mla_dense_decoding.py
python tests/test_flash_mla_sparse_decoding.py
```
Dense MLA decoding kernel 在配备 CUDA 12.8 的 H800 SXM5 上,于内存受限配置下实现了高达 3000 GB/s 的性能,于计算受限配置下实现了 660 TFLOPS。Token-level sparse MLA decoding kernel(使用 FP8 KV cache,同时以 bfloat16 执行矩阵乘法)在配备 CUDA 12.8 的 H800 SXM5 上于计算受限配置下实现了 410 TFLOPS,并在 B200 上实现了高达 350 TFlops(尚未完全优化)。
#### 测试 & 基准测试 MHA prefill (Dense):
```
python tests/test_fmha_sm100.py
```
根据 NVIDIA 的报告,它在 B200 上的 forward 计算中达到了 1460 TFlops,在 backward 计算中达到了 1000 TFlops。
#### 测试 & 基准测试 MLA prefill (Sparse):
```
python tests/test_flash_mla_sparse_prefill.py
```
它在配备 CUDA 12.8 的 H800 SXM5 上的 forward 计算中达到了 660 TFlops,并在配备 CUDA 12.9 的 B200 上达到了 1450 TFlops。
## 需求
- SM90 / SM100(参见下方的支持矩阵)
- CUDA 12.8 及更高版本(SM100 kernel 需要 CUDA 12.9+)
- PyTorch 2.0 及更高版本
支持矩阵:
| Kernel | GPU 架构 | MLA 模式 [2] | KVCache 格式 |
| :---: | :---: | :---: | :---: |
| Dense Decoding | SM90 | MQA | BF16 |
| Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] |
| Dense Prefill | SM100 | MHA | |
| Sparse Prefill | SM90 & SM100 | MQA | |
[1]: 有关使用 FP8 KV cache 的更多详细信息,请参阅下文文档。
[2]: 此处“MLA Mode”是指用于 MLA 计算的模式。MQA 代表 Multi-Query Attention 模式(即 `head_dim_k` = 576 且 `head_dim_v` = 512),而 MHA 代表 Multi-Head Attention 模式(即 `head_dim_k` = 192 / 128 且 `head_dim_v` = 128)。有关这些模式的详细说明,请参阅 [DeepSeek V3.2 论文](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp)的附录。
## 安装
```
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .
```
## 用法
### MLA 解码
要使用 MLA decoding kernel,请在 decoding 循环之前调用一次 get_mla_metadata 以获取 tile scheduler 元数据。然后,在每个 decoding 步骤中调用 flash_mla_with_kvcache。例如:
```
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens,
s_q * h_q // h_kv,
h_kv,
h_q,
is_fp8,
topk,
)
for i in range(num_layers):
...
o_i, lse_i = flash_mla_with_kvcache(
q_i, kvcache_i, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits,
is_causal, is_fp8_kvcache, indices,
)
...
```
其中
- `s_q` 是每个 q 序列的 q token 数量。如果 MTP (speculative decoding) 被禁用,则应将其设置为 1。
- `h_kv` 是 key-value head 的数量。
- `h_q` 是 query head 的数量。
**FP8 KV Cache:**
如果 `is_fp8_kvcache` 设置为 `True`,kernel 将以“FP8 with scale”格式(如下所述)读取 KV cache。它将 cache 反量化为 bfloat16 并以 bfloat16 执行 attention 计算。输出也是 bfloat16 格式。
在“FP8 with scale”格式中,每个 token 的 KV cache 为 656 字节,结构如下:
- **前 512 字节:** “quantized NoPE”部分,包含 512 个 `float8_e4m3` 值。
- **接下来的 16 字节:** 缩放因子,包含 4 个 `float32` 值。第一个 `float32` 是前 128 个 `float8_e4m3` 值的缩放因子,第二个是接下来的 128 个,依此类推。
- **最后 128 字节:** “RoPE”部分,包含 64 个 `bfloat16` 值。为了确保精度,此部分未进行量化。
有关量化和反量化的详细信息,请参阅 `tests/quant.py`。
**Sparse Attention (`indices` tensor):**
`indices` tensor(如果提供)通过指示 kernel 仅计算指定 token 的 attention 来启用 token-level sparse attention。
- **形状:** `indices` 应为一个形状为 `(batch_size, seq_len_q, topk)` 的 3D tensor。
- **格式:** `indices_in_kvcache[i][j][k] = (token t 所在的 page block 索引) * page_block_size + (token t 在 page block 内的偏移量)`,其中 `t` 是第 i 个 batch 中第 j 个 query 序列的第 k 个 token。由于 page block 的索引已经编码在 `indices_in_kvcache` 中,因此 kernel 不需要 `block_table` 参数。
- **无效条目:** 将无效索引设置为 `-1`。
**返回值:**
Kernel 返回 `(out, lse)`,其中:
- `out` 是 attention 结果。
- `lse` 是每个 query head 的 attention 分数的 log-sum-exp 值。
完整示例请参见 `tests/test_flash_mla_decoding.py`。
### 稀疏 MLA Prefill
对于 sparse MLA prefill kernel,请使用以下参数直接调用 `flash_mla_sparse_fwd`:
- `q`:形状为 `[s_q, h_q, d_qk]` 的 Query tensor
- `kv`:形状为 `[s_kv, h_kv, d_qk]` 的 Key-Value tensor
- `indices`:形状为 `[s_q, h_kv, topk]` 的 Indices tensor
- `sm_scale`:一个标量值
**关于批处理的说明:** 此 kernel 不支持 batch 维度。对于多 batch 推理,请重塑输入 tensor 并调整 `indices` 参数以模拟批处理。
**无效索引:** 将 `indices` 中的无效条目设置为 `-1` 或任何 `>= s_kv` 的数字。
**返回值和等效 PyTorch 代码:**
Kernel 返回 `(out, max_logits, lse)`。这等效于以下 PyTorch 操作:
```
Q: [s_q, h_q, d_qk], bfloat16
kv: [s_kv, h_kv, d_qk], bfloat16
indices: [s_q, h_kv, topk], int32
kv = kv.squeeze(1) # [s_kv, d_qk], h_kv must be 1
indices = indices.squeeze(1) # [s_q, topk]
focused_kv = kv[indices] # For the i-th sequence (s_q), the corresponding KV tokens are selected from the KV cache based on indices[i, :]. This operation results in a tensor of shape [s_q, topk, d_qk].
P = (Q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e) # [s_q, h_q, topk]
max_logits = P.max(dim=-1) # [s_q, h_q]
lse = log2sumexp2(P, dim=-1, base=2) # [s_q, h_q],"log2sumexp2" means that the exponentiation and logarithm are base-2
S = exp2(P - lse) # [s_q, h_q, topk]
out = S @ focused_kv # [s_q, h_q, d_qk]
return (out, max_logits, lse)
```
完整示例请参见 `tests/test_flash_mla_prefill.py`。
### 密集 MHA Prefill
此 kernel 实现了标准的 dense Multi-Head Attention (MHA) forward 和 backward 操作。可以通过以下方式调用:
- `flash_attn_varlen_func`
- `flash_attn_varlen_qkvpacked_func`
- `flash_attn_varlen_kvpacked_func`
其用法类似于 `flash_attn` 包。完整示例请参见 `tests/test_fmha_sm100.py`。
## 致谢
FlashMLA 受到了 [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) 和 [cutlass](https://github.com/nvidia/cutlass) 项目的启发。
## 社区支持
### MetaX
对于 MetaX GPU,请访问官方网站:[MetaX](https://www.metax-tech.com)。
相应的 FlashMLA 版本可在以下位置找到:[MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA)
### Moore Threads
对于 Moore Threads GPU,请访问官方网站:[Moore Threads](https://www.mthreads.com/)。
相应的 FlashMLA 版本可在 GitHub 上找到:[MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA)。
### Hygon DCU
对于 Hygon DCU,请访问官方网站:[Hygon Developer](https://developer.sourcefind.cn/)。
相应的 FlashMLA 版本可在此处获取:[OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention)。
### Intellifusion
对于 Intellifusion NNP,请访问官方网站:[Intellifusion](https://www.intellif.com)。
相应的 FlashMLA 版本可在 Gitee 上找到:[Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py)。
### Iluvatar Corex
对于 Iluvatar Corex GPU,请访问官方网站:[Iluvatar Corex](https://www.iluvatar.com)。
相应的 FlashMLA 版本可在 GitHub 上找到:[Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla)
### AMD Instinct
对于 AMD Instinct GPU,请访问官方网站:[AMD Instinct](https://www.amd.com/en/products/accelerators/instinct.html)。
相应的 FlashMLA 版本可在以下位置找到:[AITER/MLA](https://github.com/ROCm/aiter/blob/main/aiter/mla.py)
## 引用
```
@misc{flashmla2025,
title={FlashMLA: Efficient Multi-head Latent Attention Kernels},
author={Jiashi Li, Shengyu Liu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}
```
标签:CUDA内核, DeepSeek, DeepSeek-V3, DSA, FlashMLA, FP8, H800, Hopper架构, HPC, KV Cache, MLA, Token级稀疏注意力, Vectored Exception Handling, 人工智能, 凭据扫描, 多头潜在注意力, 大模型推理, 底层优化, 显存优化, 注意力机制优化, 深度学习, 熵值分析, 用户模式Hook绕过, 稀疏注意力, 稠密注意力, 算子开发, 解码, 计算性能优化, 逆向工具, 预填充, 高性能计算