state-spaces/mamba

GitHub: state-spaces/mamba

Mamba是一种新型状态空间模型架构,通过选择性SSM实现线性时间复杂度的序列建模,可作为Transformer的高效替代方案用于语言建模任务。

Stars: 17582 | Forks: 1634

# Mamba ![Mamba](assets/selection.png "选择性状态空间") ![Mamba-2](assets/ssd_algorithm.png "状态空间对偶模型") ![Mamba-3](assets/mamba3.png "推理优先状态空间模型") ## 关于 Mamba 是一种新的状态空间模型架构,在语言建模等信息密集型数据上表现出令人瞩目的性能,而在此之前的次二次模型则不及 Transformers。 它基于 [结构化状态空间模型](https://github.com/state-spaces/s4) 的进展路线, 并采用了具有高效的硬件感知设计和实现,灵感源自 [FlashAttention](https://github.com/Dao-AILab/flash-attention)。 ## 安装 首先安装 PyTorch,然后: - [可选] `pip install causal-conv1d>=1.4.0 --no-build-isolation`:Mamba 块内部使用的简单因果 Conv1d 层的高效实现。 - `pip install mamba-ssm --no-build-isolation`:核心 Mamba 包。 - `pip install mamba-ssm[causal-conv1d] --no-build-isolation`:安装核心 Mamba 包和 causal-conv1d。 需要使用 `--no-build-isolation`,以便 pip 使用您现有的启用 CUDA 的 PyTorch,而不是在隔离的构建环境中安装 torch-cpu。 其他要求: - Linux - NVIDIA GPU - PyTorch 1.12+ - CUDA 11.6+ 关于 AMD 显卡,请参阅下文的额外前置条件。 ## 用法 我们提供了多个层次的接口来使用 Mamba 模型。 ### 选择性 SSM Mamba 基于选择性 SSM 层,这是论文的重点(第 3 节;算法 2)。 源码:[ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py)。 ### Mamba 块 此代码库的主要模块是封装了选择性 SSM 的 Mamba 架构块。 源码:[modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py)。 用法: ``` import torch from mamba_ssm import Mamba batch, length, dim = 2, 64, 16 x = torch.randn(batch, length, dim).to("cuda") model = Mamba( # This module uses roughly 3 * expand * d_model^2 parameters d_model=dim, # Model dimension d_model d_state=16, # SSM state expansion factor d_conv=4, # Local convolution width expand=2, # Block expansion factor ).to("cuda") y = model(x) assert y.shape == x.shape ``` ### Mamba-2 Mamba-2 块的实现位于 [modules/mamba2.py](mamba_ssm/modules/mamba2.py)。 一个更简单的版本位于 [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py) 其用法与 Mamba(-1) 类似: ``` from mamba_ssm import Mamba2 model = Mamba2( # This module uses roughly 3 * expand * d_model^2 parameters d_model=dim, # Model dimension d_model d_state=64, # SSM state expansion factor, typically 64 or 128 d_conv=4, # Local convolution width expand=2, # Block expansion factor ).to("cuda") y = model(x) assert y.shape == x.shape ``` #### SSD 内部 SSD 模块的最小版本(Mamba-2 论文中的清单 1),包含“离散”和“连续”SSM 版本之间的转换 位于 [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py)。 ### Mamba 语言模型 最后,我们提供了一个完整语言模型的示例:深度序列模型主干(包含重复的 Mamba 块)+ 语言模型头。 源码:[models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py)。 这是一个如何将 Mamba 集成到端到端神经网络中的示例。 此示例用于下文的生成脚本中。 ## 预训练模型 预训练模型已上传至 [Hugging Face](https://huggingface.co/state-spaces):`mamba-130m`,`mamba-370m`, `mamba-790m`,`mamba-1.4b`,`mamba-2.8b`,`mamba2-130m`,`mamba2-370m`, `mamba2-780m`,`mamba2-1.3b`,`mamba2-2.7b`,`transformerpp-2.7b`,`mamba2attn-2.7b`,在 Pile 上基于 300B token 训练,以及 `mamba-2.8b-slimpj` (在 SlimPajama 数据集上基于 600B token 训练)。 这些模型将由下文的生成脚本自动下载。 这些模型是在 [Pile](https://huggingface.co/datasets/EleutherAI/pile) 上训练的,并遵循 GPT-3 描述的以及许多开源模型采用的标准模型维度: | Parameters | Layers | Model dim. | |------------|--------|------------| | 130M | 24 | 768 | | 370M | 48 | 1024 | | 790M | 48 | 1536 | | 1.4B | 48 | 2048 | | 2.8B | 64 | 2560 | (Mamba 的层数是相同大小 Transformer 的两倍,因为 Transformer 的每个“层”(MHA 块 + MLP 块)需要两个 Mamba 块。) 注意:这些是仅基于 300B token 训练的基础模型,没有任何形式的下游修改(如 instruction tuning 等)。 预期性能应与在类似数据上训练的其他架构相当或更好,但无法匹配更大或经过微调的模型。 ## 评估 要运行模型的零样本评估(对应于论文中的表 3), 我们使用 [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) 库。 1. 通过 `pip install lm-eval==0.4.2` 安装 `lm-evaluation-harness`。 2. 使用以下命令运行评估(更多文档请参见 [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) 仓库): ``` lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 ``` 要复现博客文章中报告的 `mamba-2.8b-slimpj` 模型的结果: ``` lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256 ``` 要在 Mamba-2 模型上运行评估,只需替换模型名称: ``` lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 ``` 请注意,由于评估过程中的噪声,每个任务的结果可能与报告值有 0.1-0.3 的差异。 ## 推理 脚本 [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py) 1. 从 Hugging Face Hub 自动加载模型, 2. 生成用户指定提示的补全, 3. 对此生成的推理速度进行基准测试。 其他可配置选项包括 top-p(nucleus sampling)概率和 softmax 温度。 ### 示例 要测试使用不同采样策略的生成延迟(例如 batch size = 1): ``` python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2 ``` 要测试使用随机提示的生成吞吐量(例如大 batch size): ``` python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64 python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64 ``` 对于 Mamba-2,只需更改模型名称: ``` python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 ``` ## 故障排除 ### 精度 我们的模型使用 PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) 进行混合精度训练。AMP 将模型参数保持在 float32 中,并在必要时转换为半精度。 另一方面,其他框架如 DeepSpeed 将参数存储在 float16 中,并在必要时向上转换(例如用于优化器累加)。 我们观察到主模型参数可能需要更高的精度,因为 SSM 对其循环动态很敏感。如果您遇到不稳定的情况, 作为第一步,请尝试使用以 fp32 存储参数的框架(如 AMP)。 ### 初始化 模型的某些部分具有从 S4 模型的先前工作中继承的初始化方式。 例如(https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102),$\Delta$ 参数通过初始化其线性投影的偏置来具有目标范围。 但是,某些框架可能具有初始化后钩子(例如将 `nn.Linear` 模块中的所有偏置项设置为零)。 如果是这种情况,您可能需要添加自定义逻辑(例如这 [行](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) 在我们的训练器中关闭了重新初始化,但在任何其他框架中将无操作) 这是特定于训练框架的。 ## AMD 显卡的额外前置条件 ### 修补 ROCm 如果您使用的是 ROCm 6.0,请运行以下步骤以避免编译期间的错误。这对于 ROCm 6.1 及更高版本不是必需的。 1. 找到您的 ROCm 安装目录。通常位于 `/opt/rocm/`,但可能因您的安装而异。 2. 应用补丁。如果遇到权限问题,请使用 `sudo` 运行。 patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch ## 引用 如果您使用此代码库,或认为我们的工作有价值,请引用 Mamba: ``` @article{mamba, title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, author={Gu, Albert and Dao, Tri}, journal={arXiv preprint arXiv:2312.00752}, year={2023} } @inproceedings{mamba2, title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality}, author={Dao, Tri and Gu, Albert}, booktitle={International Conference on Machine Learning (ICML)}, year={2024} } @misc{lahoti2026mamba3improvedsequencemodeling, title={Mamba-3: Improved Sequence Modeling using State Space Principles}, author={Aakash Lahoti and Kevin Y. Li and Berlin Chen and Caitlin Wang and Aviv Bick and J. Zico Kolter and Tri Dao and Albert Gu}, year={2026}, eprint={2603.15569}, archivePrefix={arXiv}, primaryClass={cs.LG}, url={https://arxiv.org/abs/2603.15569}, } ```
标签:CUDA加速, DLL 劫持, Mamba, PyTorch, SSM, Transformer替代方案, Vectored Exception Handling, 凭据扫描, 因果卷积, 大语言模型, 开源模型, 深度学习, 状态空间模型, 生成式AI, 硬件感知算法, 神经网络架构, 索引, 逆向工具, 选择性状态空间, 长序列建模, 高效推理