state-spaces/mamba
GitHub: state-spaces/mamba
Mamba是一种新型状态空间模型架构,通过选择性SSM实现线性时间复杂度的序列建模,可作为Transformer的高效替代方案用于语言建模任务。
Stars: 17582 | Forks: 1634
# Mamba



## 关于
Mamba 是一种新的状态空间模型架构,在语言建模等信息密集型数据上表现出令人瞩目的性能,而在此之前的次二次模型则不及 Transformers。
它基于 [结构化状态空间模型](https://github.com/state-spaces/s4) 的进展路线,
并采用了具有高效的硬件感知设计和实现,灵感源自 [FlashAttention](https://github.com/Dao-AILab/flash-attention)。
## 安装
首先安装 PyTorch,然后:
- [可选] `pip install causal-conv1d>=1.4.0 --no-build-isolation`:Mamba 块内部使用的简单因果 Conv1d 层的高效实现。
- `pip install mamba-ssm --no-build-isolation`:核心 Mamba 包。
- `pip install mamba-ssm[causal-conv1d] --no-build-isolation`:安装核心 Mamba 包和 causal-conv1d。
需要使用 `--no-build-isolation`,以便 pip 使用您现有的启用 CUDA 的 PyTorch,而不是在隔离的构建环境中安装 torch-cpu。
其他要求:
- Linux
- NVIDIA GPU
- PyTorch 1.12+
- CUDA 11.6+
关于 AMD 显卡,请参阅下文的额外前置条件。
## 用法
我们提供了多个层次的接口来使用 Mamba 模型。
### 选择性 SSM
Mamba 基于选择性 SSM 层,这是论文的重点(第 3 节;算法 2)。
源码:[ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py)。
### Mamba 块
此代码库的主要模块是封装了选择性 SSM 的 Mamba 架构块。
源码:[modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py)。
用法:
```
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
# This module uses roughly 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model
d_state=16, # SSM state expansion factor
d_conv=4, # Local convolution width
expand=2, # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
### Mamba-2
Mamba-2 块的实现位于 [modules/mamba2.py](mamba_ssm/modules/mamba2.py)。
一个更简单的版本位于 [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py)
其用法与 Mamba(-1) 类似:
```
from mamba_ssm import Mamba2
model = Mamba2(
# This module uses roughly 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model
d_state=64, # SSM state expansion factor, typically 64 or 128
d_conv=4, # Local convolution width
expand=2, # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
#### SSD
内部 SSD 模块的最小版本(Mamba-2 论文中的清单 1),包含“离散”和“连续”SSM 版本之间的转换
位于 [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py)。
### Mamba 语言模型
最后,我们提供了一个完整语言模型的示例:深度序列模型主干(包含重复的 Mamba 块)+ 语言模型头。
源码:[models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py)。
这是一个如何将 Mamba 集成到端到端神经网络中的示例。
此示例用于下文的生成脚本中。
## 预训练模型
预训练模型已上传至
[Hugging Face](https://huggingface.co/state-spaces):`mamba-130m`,`mamba-370m`,
`mamba-790m`,`mamba-1.4b`,`mamba-2.8b`,`mamba2-130m`,`mamba2-370m`,
`mamba2-780m`,`mamba2-1.3b`,`mamba2-2.7b`,`transformerpp-2.7b`,`mamba2attn-2.7b`,在 Pile 上基于 300B token 训练,以及 `mamba-2.8b-slimpj`
(在 SlimPajama 数据集上基于 600B token 训练)。
这些模型将由下文的生成脚本自动下载。
这些模型是在 [Pile](https://huggingface.co/datasets/EleutherAI/pile) 上训练的,并遵循 GPT-3 描述的以及许多开源模型采用的标准模型维度:
| Parameters | Layers | Model dim. |
|------------|--------|------------|
| 130M | 24 | 768 |
| 370M | 48 | 1024 |
| 790M | 48 | 1536 |
| 1.4B | 48 | 2048 |
| 2.8B | 64 | 2560 |
(Mamba 的层数是相同大小 Transformer 的两倍,因为 Transformer 的每个“层”(MHA 块 + MLP 块)需要两个 Mamba 块。)
注意:这些是仅基于 300B token 训练的基础模型,没有任何形式的下游修改(如 instruction tuning 等)。
预期性能应与在类似数据上训练的其他架构相当或更好,但无法匹配更大或经过微调的模型。
## 评估
要运行模型的零样本评估(对应于论文中的表 3),
我们使用
[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
库。
1. 通过 `pip install lm-eval==0.4.2` 安装 `lm-evaluation-harness`。
2. 使用以下命令运行评估(更多文档请参见 [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) 仓库):
```
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
```
要复现博客文章中报告的 `mamba-2.8b-slimpj` 模型的结果:
```
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
```
要在 Mamba-2 模型上运行评估,只需替换模型名称:
```
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
```
请注意,由于评估过程中的噪声,每个任务的结果可能与报告值有 0.1-0.3 的差异。
## 推理
脚本 [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
1. 从 Hugging Face Hub 自动加载模型,
2. 生成用户指定提示的补全,
3. 对此生成的推理速度进行基准测试。
其他可配置选项包括 top-p(nucleus sampling)概率和 softmax 温度。
### 示例
要测试使用不同采样策略的生成延迟(例如 batch size = 1):
```
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
```
要测试使用随机提示的生成吞吐量(例如大 batch size):
```
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64
```
对于 Mamba-2,只需更改模型名称:
```
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
```
## 故障排除
### 精度
我们的模型使用 PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) 进行混合精度训练。AMP 将模型参数保持在 float32 中,并在必要时转换为半精度。
另一方面,其他框架如 DeepSpeed 将参数存储在 float16 中,并在必要时向上转换(例如用于优化器累加)。
我们观察到主模型参数可能需要更高的精度,因为 SSM 对其循环动态很敏感。如果您遇到不稳定的情况,
作为第一步,请尝试使用以 fp32 存储参数的框架(如 AMP)。
### 初始化
模型的某些部分具有从 S4 模型的先前工作中继承的初始化方式。
例如(https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102),$\Delta$ 参数通过初始化其线性投影的偏置来具有目标范围。
但是,某些框架可能具有初始化后钩子(例如将 `nn.Linear` 模块中的所有偏置项设置为零)。
如果是这种情况,您可能需要添加自定义逻辑(例如这 [行](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) 在我们的训练器中关闭了重新初始化,但在任何其他框架中将无操作)
这是特定于训练框架的。
## AMD 显卡的额外前置条件
### 修补 ROCm
如果您使用的是 ROCm 6.0,请运行以下步骤以避免编译期间的错误。这对于 ROCm 6.1 及更高版本不是必需的。
1. 找到您的 ROCm 安装目录。通常位于 `/opt/rocm/`,但可能因您的安装而异。
2. 应用补丁。如果遇到权限问题,请使用 `sudo` 运行。
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
## 引用
如果您使用此代码库,或认为我们的工作有价值,请引用 Mamba:
```
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}
@misc{lahoti2026mamba3improvedsequencemodeling,
title={Mamba-3: Improved Sequence Modeling using State Space Principles},
author={Aakash Lahoti and Kevin Y. Li and Berlin Chen and Caitlin Wang and Aviv Bick and J. Zico Kolter and Tri Dao and Albert Gu},
year={2026},
eprint={2603.15569},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2603.15569},
}
```
标签:CUDA加速, DLL 劫持, Mamba, PyTorch, SSM, Transformer替代方案, Vectored Exception Handling, 凭据扫描, 因果卷积, 大语言模型, 开源模型, 深度学习, 状态空间模型, 生成式AI, 硬件感知算法, 神经网络架构, 索引, 逆向工具, 选择性状态空间, 长序列建模, 高效推理