deepseek-ai/FlashMLA

GitHub: deepseek-ai/FlashMLA

FlashMLA 是 DeepSeek 开源的高效多头潜在注意力核心库,为 DeepSeek-V3 系列模型提供推理加速支持。

Stars: 12513 | Forks: 997

# FlashMLA ## 介绍 FlashMLA 是 DeepSeek 的优化 attention kernel 库,为 [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) 和 [DeepSeek-V3.2-Exp](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) 模型提供支持。此代码库包含以下实现: **Sparse Attention Kernels** *这些 kernel 支持 DeepSeek Sparse Attention (DSA),如[本文](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp)所述。* - 用于 prefill 阶段的 Token-level sparse attention - 用于 decoding 阶段的 Token-level sparse attention,支持 FP8 KV cache **Dense Attention Kernels** - 用于 prefill 阶段的 Dense attention - 用于 decoding 阶段的 Dense attention ## 新闻 - **2025.09.29 发布 Sparse Attention Kernels**:随着 [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) 的发布,我们发布了相应的 token-level sparse attention kernel。这些 kernel 支持模型的 DeepSeek Sparse Attention (DSA),在 prefill 期间可达到 640 TFlops,在 decoding 期间可达到 410 TFlops。我们还发布了一篇关于我们新 FP8 sparse decoding kernel 的深度解析博客。点击[这里](docs/20250929-hopper-fp8-sparse-deep-dive.md)查看。 - **2025.08.01 SM100 上的 MHA Kernels**:感谢 NVIDIA 提交的关于 SM100 上 MHA forward / backward kernel 的 [PR](https://github.com/deepseek-ai/FlashMLA/pull/76)! - **2025.04.22 深度解析博客**:我们非常乐意分享新 FlashMLA kernel 背后的技术细节!请在此处查看我们的深度解析文章[点击查看](docs/20250422-new-kernel-deep-dive.md)。 - **2025.04.22 性能更新**:我们很高兴地宣布 Flash MLA 的新版本发布,该版本在计算密集型工作负载上实现了 5% ~ 15% 的性能提升,在 NVIDIA H800 SXM5 GPU 上达到了 660 TFlops。新版本的接口与旧版本完全兼容。只需升级到新版本即可立即获得性能提升!🚀🚀🚀 ## 性能 #### 测试 & 基准测试 MLA decoding (Sparse & Dense): ``` python tests/test_flash_mla_dense_decoding.py python tests/test_flash_mla_sparse_decoding.py ``` Dense MLA decoding kernel 在配备 CUDA 12.8 的 H800 SXM5 上,于内存受限配置下实现了高达 3000 GB/s 的性能,于计算受限配置下实现了 660 TFLOPS。Token-level sparse MLA decoding kernel(使用 FP8 KV cache,同时以 bfloat16 执行矩阵乘法)在配备 CUDA 12.8 的 H800 SXM5 上于计算受限配置下实现了 410 TFLOPS,并在 B200 上实现了高达 350 TFlops(尚未完全优化)。 #### 测试 & 基准测试 MHA prefill (Dense): ``` python tests/test_fmha_sm100.py ``` 根据 NVIDIA 的报告,它在 B200 上的 forward 计算中达到了 1460 TFlops,在 backward 计算中达到了 1000 TFlops。 #### 测试 & 基准测试 MLA prefill (Sparse): ``` python tests/test_flash_mla_sparse_prefill.py ``` 它在配备 CUDA 12.8 的 H800 SXM5 上的 forward 计算中达到了 660 TFlops,并在配备 CUDA 12.9 的 B200 上达到了 1450 TFlops。 ## 需求 - SM90 / SM100(参见下方的支持矩阵) - CUDA 12.8 及更高版本(SM100 kernel 需要 CUDA 12.9+) - PyTorch 2.0 及更高版本 支持矩阵: | Kernel | GPU 架构 | MLA 模式 [2] | KVCache 格式 | | :---: | :---: | :---: | :---: | | Dense Decoding | SM90 | MQA | BF16 | | Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] | | Dense Prefill | SM100 | MHA | | | Sparse Prefill | SM90 & SM100 | MQA | | [1]: 有关使用 FP8 KV cache 的更多详细信息,请参阅下文文档。 [2]: 此处“MLA Mode”是指用于 MLA 计算的模式。MQA 代表 Multi-Query Attention 模式(即 `head_dim_k` = 576 且 `head_dim_v` = 512),而 MHA 代表 Multi-Head Attention 模式(即 `head_dim_k` = 192 / 128 且 `head_dim_v` = 128)。有关这些模式的详细说明,请参阅 [DeepSeek V3.2 论文](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp)的附录。 ## 安装 ``` git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla cd flash-mla git submodule update --init --recursive pip install -v . ``` ## 用法 ### MLA 解码 要使用 MLA decoding kernel,请在 decoding 循环之前调用一次 get_mla_metadata 以获取 tile scheduler 元数据。然后,在每个 decoding 步骤中调用 flash_mla_with_kvcache。例如: ``` from flash_mla import get_mla_metadata, flash_mla_with_kvcache tile_scheduler_metadata, num_splits = get_mla_metadata( cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, is_fp8, topk, ) for i in range(num_layers): ... o_i, lse_i = flash_mla_with_kvcache( q_i, kvcache_i, block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, is_causal, is_fp8_kvcache, indices, ) ... ``` 其中 - `s_q` 是每个 q 序列的 q token 数量。如果 MTP (speculative decoding) 被禁用,则应将其设置为 1。 - `h_kv` 是 key-value head 的数量。 - `h_q` 是 query head 的数量。 **FP8 KV Cache:** 如果 `is_fp8_kvcache` 设置为 `True`,kernel 将以“FP8 with scale”格式(如下所述)读取 KV cache。它将 cache 反量化为 bfloat16 并以 bfloat16 执行 attention 计算。输出也是 bfloat16 格式。 在“FP8 with scale”格式中,每个 token 的 KV cache 为 656 字节,结构如下: - **前 512 字节:** “quantized NoPE”部分,包含 512 个 `float8_e4m3` 值。 - **接下来的 16 字节:** 缩放因子,包含 4 个 `float32` 值。第一个 `float32` 是前 128 个 `float8_e4m3` 值的缩放因子,第二个是接下来的 128 个,依此类推。 - **最后 128 字节:** “RoPE”部分,包含 64 个 `bfloat16` 值。为了确保精度,此部分未进行量化。 有关量化和反量化的详细信息,请参阅 `tests/quant.py`。 **Sparse Attention (`indices` tensor):** `indices` tensor(如果提供)通过指示 kernel 仅计算指定 token 的 attention 来启用 token-level sparse attention。 - **形状:** `indices` 应为一个形状为 `(batch_size, seq_len_q, topk)` 的 3D tensor。 - **格式:** `indices_in_kvcache[i][j][k] = (token t 所在的 page block 索引) * page_block_size + (token t 在 page block 内的偏移量)`,其中 `t` 是第 i 个 batch 中第 j 个 query 序列的第 k 个 token。由于 page block 的索引已经编码在 `indices_in_kvcache` 中,因此 kernel 不需要 `block_table` 参数。 - **无效条目:** 将无效索引设置为 `-1`。 **返回值:** Kernel 返回 `(out, lse)`,其中: - `out` 是 attention 结果。 - `lse` 是每个 query head 的 attention 分数的 log-sum-exp 值。 完整示例请参见 `tests/test_flash_mla_decoding.py`。 ### 稀疏 MLA Prefill 对于 sparse MLA prefill kernel,请使用以下参数直接调用 `flash_mla_sparse_fwd`: - `q`:形状为 `[s_q, h_q, d_qk]` 的 Query tensor - `kv`:形状为 `[s_kv, h_kv, d_qk]` 的 Key-Value tensor - `indices`:形状为 `[s_q, h_kv, topk]` 的 Indices tensor - `sm_scale`:一个标量值 **关于批处理的说明:** 此 kernel 不支持 batch 维度。对于多 batch 推理,请重塑输入 tensor 并调整 `indices` 参数以模拟批处理。 **无效索引:** 将 `indices` 中的无效条目设置为 `-1` 或任何 `>= s_kv` 的数字。 **返回值和等效 PyTorch 代码:** Kernel 返回 `(out, max_logits, lse)`。这等效于以下 PyTorch 操作: ``` Q: [s_q, h_q, d_qk], bfloat16 kv: [s_kv, h_kv, d_qk], bfloat16 indices: [s_q, h_kv, topk], int32 kv = kv.squeeze(1) # [s_kv, d_qk], h_kv must be 1 indices = indices.squeeze(1) # [s_q, topk] focused_kv = kv[indices] # For the i-th sequence (s_q), the corresponding KV tokens are selected from the KV cache based on indices[i, :]. This operation results in a tensor of shape [s_q, topk, d_qk]. P = (Q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e) # [s_q, h_q, topk] max_logits = P.max(dim=-1) # [s_q, h_q] lse = log2sumexp2(P, dim=-1, base=2) # [s_q, h_q],"log2sumexp2" means that the exponentiation and logarithm are base-2 S = exp2(P - lse) # [s_q, h_q, topk] out = S @ focused_kv # [s_q, h_q, d_qk] return (out, max_logits, lse) ``` 完整示例请参见 `tests/test_flash_mla_prefill.py`。 ### 密集 MHA Prefill 此 kernel 实现了标准的 dense Multi-Head Attention (MHA) forward 和 backward 操作。可以通过以下方式调用: - `flash_attn_varlen_func` - `flash_attn_varlen_qkvpacked_func` - `flash_attn_varlen_kvpacked_func` 其用法类似于 `flash_attn` 包。完整示例请参见 `tests/test_fmha_sm100.py`。 ## 致谢 FlashMLA 受到了 [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) 和 [cutlass](https://github.com/nvidia/cutlass) 项目的启发。 ## 社区支持 ### MetaX 对于 MetaX GPU,请访问官方网站:[MetaX](https://www.metax-tech.com)。 相应的 FlashMLA 版本可在以下位置找到:[MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) ### Moore Threads 对于 Moore Threads GPU,请访问官方网站:[Moore Threads](https://www.mthreads.com/)。 相应的 FlashMLA 版本可在 GitHub 上找到:[MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA)。 ### Hygon DCU 对于 Hygon DCU,请访问官方网站:[Hygon Developer](https://developer.sourcefind.cn/)。 相应的 FlashMLA 版本可在此处获取:[OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention)。 ### Intellifusion 对于 Intellifusion NNP,请访问官方网站:[Intellifusion](https://www.intellif.com)。 相应的 FlashMLA 版本可在 Gitee 上找到:[Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py)。 ### Iluvatar Corex 对于 Iluvatar Corex GPU,请访问官方网站:[Iluvatar Corex](https://www.iluvatar.com)。 相应的 FlashMLA 版本可在 GitHub 上找到:[Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) ### AMD Instinct 对于 AMD Instinct GPU,请访问官方网站:[AMD Instinct](https://www.amd.com/en/products/accelerators/instinct.html)。 相应的 FlashMLA 版本可在以下位置找到:[AITER/MLA](https://github.com/ROCm/aiter/blob/main/aiter/mla.py) ## 引用 ``` @misc{flashmla2025, title={FlashMLA: Efficient Multi-head Latent Attention Kernels}, author={Jiashi Li, Shengyu Liu}, year={2025}, publisher = {GitHub}, howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}}, } ```
标签:CUDA内核, DeepSeek, DeepSeek-V3, DSA, FlashMLA, FP8, H800, Hopper架构, HPC, KV Cache, MLA, Token级稀疏注意力, Vectored Exception Handling, 人工智能, 凭据扫描, 多头潜在注意力, 大模型推理, 底层优化, 显存优化, 注意力机制优化, 深度学习, 熵值分析, 用户模式Hook绕过, 稀疏注意力, 稠密注意力, 算子开发, 解码, 计算性能优化, 逆向工具, 预填充, 高性能计算