memo-ozdincer/RRFA
GitHub: memo-ozdincer/RRFA
基于LoRA电路断路器和三元组损失的LLM智能体提示注入防御框架,通过表示重路由实现工具调用安全保护。
Stars: 0 | Forks: 0
RRFA: Representation Rerouting for Agentic Safety
Internal defenses against prompt injection via LoRA circuit breakers and triplet loss
📕 Internal Research Report
🤗 HuggingFace Models
🤗 HuggingFace Datasets
训练 LoRA 适配器,使有害的内部表示与良性表示正交。当模型遇到提示注入时,其内部状态会自动重路由到安全行为——拒绝注入的工具调用,或者显著地,**恢复正确的预期动作**。将 Circuit Breakers 框架([Zou et al., 2024](https://arxiv.org/abs/2406.04313))从纯文本安全扩展到**智能体工具调用安全**,采用了新颖的三元组损失和可配置的损失掩码。
## 结果
| 配置 | 基线 ASR | CB ASR | 降低幅度 | 回归数 | AgentDojo 差异 | LLMail ASR |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
| **α=10.0, L{10,20}, cb_full_seq** | 83.7% | **8.2%** | **75.5pp** | **0** | **100%** | **5.0%** |
| α=5.0, L{10,20}, cb_full_seq | 86.7% | 11.2% | 75.5pp | 0 | 100% | 5.0% |
| α=15.0, L{10,20}, cb_full_seq | 84.7% | 14.3% | 70.4pp | 0 | 99% | 5.0% |
**关键发现:**
- **零回归**:CB 从未导致先前被阻止的攻击成功。
- **能力恢复**:在许多情况下,CB 模型不仅仅是拒绝——它会忽略注入并执行正确的用户意图(例如,在基线调用无关工具时创建正确的日历事件)。
- **跨数据集迁移**:在 Fujitsu tool-flip 数据上训练,能泛化到 AgentDojo(工具响应中的注入)和 LLMail(反向语义:正确 = 无工具调用)。
定性示例(点击展开)
**示例 1 — 能力恢复 (AgentDojo):**
**示例 2 — 强硬拒绝 (AgentDojo):**
**示例 3 — 优雅降级 (AgentDojo):**
## 方法
### 问题表述
对于具有工具 T、用户查询 q 和注入上下文 c 的智能体:
```
HARM(q, c) := (t_observed(q ⊕ c) ≠ t_expected(q)) ∧ injection_present(c)
```
危害 = 注入导致调用了不同的工具。二元、确定性,无需 LLM 评估。
### 三元组损失
我们通过结构化的三元组公式扩展了原始的 ReLU-cosine CB 损失。令 z̄_h = 批次有害质心:
```
L_benign = ReLU( d(h_frozen_b, h_model_b) - d(h_model_b, z̄_h) + m_b )
L_harmful = ReLU( d(h_model_h, z̄_h) - d(h_model_h, h_frozen_h) + m_h )
L_KL = KL( p_θ(·|x_b) ‖ p_θ₀(·|x_b) )
L_total = α_b · L_benign + β_h · L_harmful + γ · L_KL
```
**直觉:** L_benign 使良性表示比有害质心更接近冻结表示。L_harmful 将有害表示推向质心并远离冻结表示。L_KL 保留输出分布。
**距离函数**(每项可配置):d_L2 = ‖a-b‖₂, d_cos = 1-cos(a,b), d_mix = w₁·d_L2 + w₂·d_cos。默认值:d_mix,其中 w₁=w₂=0.5。
### 损失掩码策略 (LMP)
哪些 token 接收重路由损失:
| 策略 | Token | 备注 |
|:---|:---|:---|
| `assistant_only` | 所有 assistant 轮次 token | 标准 CB 方法 |
| `assistant_and_tool` | Assistant + 工具调用参数 | 包括参数 |
| **`cb_full_sequence`** | **整个序列** | **(最佳)** 学习检测上下文中的注入 |
| `tool_calls_only` | `<\|python_tag\|>{...}<\|eom_id\|>` | 最窄地关注工具调用 |
| `completion_only` | 最终 assistant 补全 | 旧版 CB 风格 |
**为什么 `cb_full_sequence` 胜出:** 将损失应用于注入 token 本身(而不仅仅是生成的工具调用)让模型学习注入*检测*。上下文表示在生成开始之前就被重塑,创造了一个早期的“绊网”。
### 内存优化
单模型架构:启用适配器 → θ(可训练);通过 `disable_adapter()` 禁用适配器 → θ₀(冻结参考)。相比加载两个模型副本,**VRAM 减半**。可在单张 80GB H100 上实现 MAX_SEQ_LENGTH=4096。
### 数据生成
**D_s(有害):** 强制性系统提示,T=0.7。仅保留攻击成功的样本。AgentDojo:`security==False` 轨迹。
**D_r(良性孪生):** 相同上下文,去除注入,防御性提示,T=0.3。教导模型*本应*执行的操作。
## 数据集
### Fujitsu B4 — Tool-Flip 编排攻击
| 属性 | 值 |
|:---|:---|
| **来源** | `data/fujitsu/orchestrator_attacks_combined_deduplicated.jsonl` |
| **大小** | 13K+ 攻击记录 |
| **攻击** | 注入将 `retrieve_multimodal_docs` 翻转为 `search_web` |
| **关键字段** | `benign_query`, `malicious_injection`, `combined_query`, `expected_tool`, `simulated_tool` |
| **处理** | ETL_A → 骨架轨迹 → vLLM DS/DR 生成 |
### AgentDojo — 多域注入
| 属性 | 值 |
|:---|:---|
| **来源** | `data/agent_dojo/agentdojo_spotlight_extract.jsonl` |
| **大小** | 194 条轨迹(银行、工作区、旅行等) |
| **攻击** | 注入在工具响应中的 `
` 标签 |
| **关键字段** | `metadata.security` (True=已防御), `metadata.injection_task_id` |
| **处理** | ETL_A → 完整轨迹 → 基于标签的划分 |
### LLMail-Inject — 邮件智能体攻击
| 属性 | 值 |
|:---|:---|
| **来源** | `data/llmail_inject/raw_submissions_phase1.jsonl` |
| **攻击** | 邮件注入指令以调用 `send_email`(数据窃取) |
| **语义** | 挑战忠实的单轮检索上下文;正确行为 = 无有害工具调用 |
| **指标** | 攻击 ASR(send_email 率),有用性(良性质量) |
| **处理** | ETL_A(基于场景的检索重建)→ 骨架轨迹 → vLLM DS/DR 生成 |
## 流水线架构
```
┌──────────────────────────────────────────────────────────────────────────┐
│ RRFA Training Pipeline │
├──────────┬──────────┬──────────┬──────────┬──────────┬─────────────────┤
│ ETL_A │ Generate │ Judge │ ETL_B │ Train │ Eval │
│ Raw→ │ Skeleton→│ Unlabel→ │ Trace→ │ Circuit │ ASR+Capability │
│ trace_v1 │ DS/DR │ Labeled │ Render+ │ Breaker │ +LLMail Attack │
│ │ via vLLM │ (opt.) │ LossMask │ Trainer │ +LLMail Useful │
└──────────┴──────────┴──────────┴──────────┴──────────┴─────────────────┘
```
**分层数据模型:**
| 层级 | 状态 | 描述 |
|:---|:---|:---|
| **A** | Raw | 原始数据集格式(Fujitsu JSONL, AgentDojo JSONL, LLMail JSONL) |
| **B1** | Skeleton | 标准化 `trace_v1` — 消息、标签、来源 — 无 assistant 补全 |
| **B2** | Complete | 带有来自 vLLM 的 DS(有害)和 DR(良性)补全的 `trace_v1` |
| **C** | Rendered | 通过 `apply_chat_template` 进行分词 + 来自 LMP 策略的逐 token 损失掩码 |
**规范模式** (`trace_v1`):`messages[]` (system/user/assistant/tool), `labels` (category, attack_present, attack_succeeded), `source` (dataset, subset, record_locator), `signal_hints` (tool expectations, injection spans), `completeness` (skeleton/complete), `tier` (B1/B2)。
## 快速开始
### 前置条件
```
pip install -r requirements.txt
# Requires: torch, transformers, peft, vllm, accelerate, tqdm
```
### 1. 缓存模型
```
bash slurm/cache_models.sh # Run on login node before submitting jobs
```
### 2. 运行超参数扫描
```
# Primary entry point — 自动生成缺失数据,然后运行 sweep
sbatch slurm/pipeline/sweep_hparams_simple.sbatch
# Custom sweep
ALPHAS=5.0,10.0,15.0,20.0 CB_LAYERS=10,20 \
sbatch slurm/pipeline/sweep_hparams_simple.sbatch
# Full pipeline (ETL_A → Generate → Judge → ETL_B → Split → Train → Eval)
sbatch slurm/pipeline/unified_pipeline.sbatch
```
### 3. 分析结果
```
# Summary table + ASCII charts + best runs
python scripts/visualize_sweep_results.py /path/to/sweep_dir
# Safety vs capability tradeoff with Pareto frontier
python scripts/plot_tradeoff.py --sweep-dir /path/to/sweep_dir
# Detailed sample viewer — 查看确切的 baseline vs CB responses
python scripts/visualize_sweep_results.py /path/to/sweep_dir --show-samples 10 --filter-success
# Compare specific samples across all runs
python scripts/visualize_sweep_results.py /path/to/sweep_dir --compare-samples 5 --compare-dataset fujitsu
# Export to CSV
python scripts/visualize_sweep_results.py /path/to/sweep_dir --csv results.csv
```
## 配置参考
### 扫描环境变量
| 变量 | 描述 | 默认值 |
|:---|:---|:---|
| `MODEL_ID` | 基础模型 ID | `meta-llama/Llama-3.1-8B-Instruct` |
| `PRESET` | 训练预设配置 | `llama-3.1-8b-instruct` |
| `ALPHAS` | Alpha 值(逗号分隔) | `5.0,10.0,15.0` |
| `CB_LAYERS` | 目标层(逗号分隔) | `10,20` |
| `LMP_POLICY` | 损失掩码策略 | `assistant_only` |
| `TOTAL_STEPS` | 每配置步数 | `200` |
| `BATCH_SIZE` | 单设备批次大小 | `1` |
| `GRAD_ACCUM` | 梯度累积 | `4` |
| `MAX_SEQ_LENGTH` | 最大序列长度 | `4096` |
| `LEARNING_RATE` | AdamW 学习率 | `5e-5` |
| `WARMUP_STEPS` | LR 预热步数 | `20` |
| `LOSS_MODE` | 损失公式 | `triplet_full` |
| `LOSS_WEIGHTING` | 系数调度 | `dual` |
| `LORA_R` | LoRA 秩 | `16` |
| `LORA_ALPHA` | LoRA 缩放 | `32` |
| `DTYPE` | 计算精度 | `bfloat16` |
| `NO_WANDB` | 禁用 W&B 日志 | `true` |
| `EVAL_LIMIT` | 最大评估样本数 | `100` |
| `USE_VLLM` | 使用 vLLM 生成 | `true` |
### 三元组损失超参数
| 参数 | 描述 | 默认值 | 备注 |
|:---|:---|:---|:---|
| `TRIPLET_ALPHA_BENIGN` | 良性三元组权重 (α_b) | `0.5` | 公式 4 系数 |
| `TRIPLET_BETA_HARMFUL` | 有害三元组权重 (β_h) | `0.4` | 公式 5 系数 |
| `TRIPLET_GAMMA_KL` | KL 散度权重 (γ) | `0.9` | 公式 6 系数 |
| `TRIPLET_MARGIN_BENIGN` | 良性边界 (m_b) | `500.0` | 铰接阈值 |
| `TRIPLET_MARGIN_HARMFUL` | 有害边界 (m_h) | `1500.0` | 铰接阈值 |
| `TRIPLET_BENIGN_POS_DISTANCE` | 良性正样本距离 | `dmix` | d_L2, d_cos, 或 d_mix |
| `TRIPLET_BENIGN_NEG_DISTANCE` | 良性负样本距离 | `dmix` | |
| `TRIPLET_HARMFUL_POS_DISTANCE` | 有害正样本距离 | `dmix` | |
| `TRIPLET_HARMFUL_NEG_DISTANCE` | 有害负样本距离 | `dmix` | |
| `TRIPLET_MIX_L2_WEIGHT` | d_mix 中的 L2 权重 | `0.5` | |
| `TRIPLET_MIX_COS_WEIGHT` | d_mix 中的 Cosine 权重 | `0.5` | |
### 阶段控制
通过环境变量跳过单个流水线阶段:
```
SKIP_ETL_A=true SKIP_GENERATE=true sbatch slurm/pipeline/unified_pipeline.sbatch # Reuse existing traces
SKIP_TRAIN=true SKIP_EVAL=true sbatch slurm/pipeline/unified_pipeline.sbatch # Data prep only
```
## 仓库结构
```
rrfa/
├── configs/
│ ├── dataset_config.yaml # Dataset types, label logic, split rules
│ ├── injection_patterns.json # Regex patterns for injection detection
│ ├── schemas/trace_v1.json # Canonical JSON schema
│ └── tool_schemas/
│ ├── b4_standard_v1.json # Fujitsu: retrieve_multimodal_docs + search_web
│ ├── llmail_inject_challenge_v2.json # LLMail: challenge-faithful, single endpoint
│ └── llmail_inject_v1.json # Legacy LLMail schema (kept for reproducibility)
│
├── data/
│ ├── fujitsu/ # Raw Fujitsu B4 attack records (13K+)
│ ├── agent_dojo/ # Raw AgentDojo traces (194)
│ └── llmail_inject/ # Raw LLMail-Inject submissions
│
├── docs/
│ ├── newschema_workingdocs.txt # PRIMARY: Pipeline architecture & schema docs
│ ├── adding_datasets.txt # Guide: integrating new datasets
│ ├── datapath.txt # Fujitsu B4 data path documentation
│ ├── overview.txt # Project overview & scratch structure
│ ├── dataset_field_mappings.yaml # Cross-dataset field mapping reference
│ └── FORMAT_AGNOSTIC_QUICK_REF.md # Format-agnostic rendering guide
│
├── paper/
│ ├── main.tex # Research paper (double-column LaTeX)
│ └── main_old.tex # Historical reference document
│
├── scripts/
│ ├── visualize_sweep_results.py # Sweep analysis: tables, samples, ASCII plots
│ ├── plot_tradeoff.py # Pareto frontier: safety vs capability
│ ├── plot_publication_figures.py # Publication-quality matplotlib figures
│ └── split_dataset.py # Split complete traces into CB/Retain sets
│
├── slurm/
│ ├── cache_models.sh # Cache HF models on login node
│ └── pipeline/
│ ├── sweep_hparams_simple.sbatch # ★ PRIMARY ENTRYPOINT — auto-gen + sweep
│ ├── sweep_hparams.sbatch # Core sweep logic (called by simple)
│ └── unified_pipeline.sbatch # Full 7-stage pipeline
│
├── src/
│ ├── data_generation/
│ │ └── generate_completions.py # vLLM DS/DR generation (modes: ds, dr, both)
│ │
│ ├── evaluation/
│ │ ├── eval.py # Eval: tool-flip ASR, LLMail attack/useful, AgentDojo
│ │ └── judge.py # LLM judge for unlabeled complete traces
│ │
│ ├── schemas/
│ │ ├── trace.py # trace_v1 Python dataclasses (Trace, Message, etc.)
│ │ ├── render.py # render_v1 types (spans, signals, alignment)
│ │ ├── lossmask.py # lossmask_v1 types (per-token mask arrays)
│ │ ├── registry.py # LMP policy registry & MWCS registry
│ │ └── tools/
│ │ ├── ETL_A.py # Raw → trace_v1 (Fujitsu, AgentDojo, LLMail)
│ │ └── ETL_B.py # trace_v1 → render_v1 + lossmask_v1
│ │
│ ├── training/
│ │ ├── trainer.py # CircuitBreakerTrainer: single-model, DDP-safe
│ │ ├── train_schema.py # Schema-aware entry point (delegates to trainer)
│ │ ├── losses.py # triplet_full, reroute_relu_cos, retain_l2, KL, distances
│ │ ├── config.py # CircuitBreakerConfig dataclass + presets
│ │ └── hf_utils.py # HF token resolution, offline model path
│ │
│ └── utils/
│ └── wandb_logging.py # W&B init, artifact logging, metadata
│
├── tests/ # Unit tests
├── requirements.txt # Python dependencies
└── README.md # This file
```
### 扫描输出结构
```
hparam_sweep_YYYYMMDD_HHMMSS/
├── summary.csv # All runs: alpha, layers, policy, all metrics
├── sweep.log # Master sweep log
├── a{X}_l{Y}_{policy}/ # Per-configuration directory
│ ├── eval/
│ │ ├── fujitsu_eval.json # {baseline: {tool_flip_asr: ...}, cb_model: {...}, delta: ...}
│ │ ├── fujitsu_eval.paired_outputs.jsonl
│ │ ├── llmail_eval.json # {baseline: {llmail_attack: ...}, cb_model: {llmail_usefulness: ...}}
│ │ └── agentdojo_eval.json # {output_comparison: {difference_rate: ...}}
│ ├── model/
│ │ └── final/ # LoRA adapter weights (adapter_config.json + weights)
│ └── etl_b.log
└── plots/ # Auto-generated tradeoff plots (if matplotlib available)
```
## 关键实现细节
### `src/training/trainer.py` — CircuitBreakerTrainer
- **单次前向传递**(DDP 安全):有害 + 良性沿批次维度连接,随后拆分。避免 DDP + 梯度检查点下的重入反向传播。
- **表示提取**通过 `output_hidden_states=True`(首选)或前向钩子(旧版)。选定层返回为 `Dict[int, Tensor]`。
- **池化表示**:在三元组计算之前通过损失掩码加权的逐样本均值池化。
- **双系数调度**:`cs(t)`(重路由,1→0)和 `cr(t)`(保留,0→1)用于旧版模式。三元组模式使用固定的 α_b/β_h/γ。
- **补全掩码验证**在训练开始时:验证 `<|python_tag|>` token 被损失掩码覆盖。
### `src/training/losses.py` — 损失函数
| 函数 | 用途 |
|:---|:---|
| `triplet_full_loss()` | 主损失:良性铰链 + 有害铰链 + KL。返回总计 + 每组件指标。 |
| `reroute_loss_relu_cos()` | 原始 CB:目标层上的 ReLU(cos_sim),掩码均值。 |
| `retain_loss_l2()` | 原始 CB:目标层上的 L2 距离,掩码均值。 |
| `kl_divergence_loss()` | 带温度缩 Token 级 KL 散度,已掩码。 |
| `pair_distance()` | 可配置:d_L2, d_cos, d_mix, d_null。由三元组项使用。 |
| `pooled_representations()` | 带 token 掩码的目标层上隐藏状态的均值池化。 |
| `random_reroute_loss()` | 旧版模式:推向随机方向。 |
| `retain_ce_loss()` | 旧版:良性输出上的交叉熵。 |
### `src/evaluation/eval.py` — 评估
- **Fujitsu**:`evaluate_tool_flip_asr()` — 筛选 `expected_tool ≠ simulated_tool` 的样本,比较基线与 CB 的工具选择。
- **LLMail**:`evaluate_llmail_attack()` — 默认情况下将检索样本的响应分类为 attack_success / refusal / other_tool。`evaluate_llmail_usefulness()` 在同一筛选子集上运行。
- **AgentDojo**:`output_comparison` — 基线与 CB 响应之间的差异率。
- **`--merge-adapter`**:在评估前将 LoRA 合并到基础权重中以加快推理。
- **配对输出**:每次评估写入包含每样本基线/CB 响应的 `.paired_outputs.jsonl` 以供详细分析。
### `src/schemas/tools/ETL_B.py` — 渲染与损失掩码
- 通过 `apply_chat_template` 渲染轨迹,并检测 Llama 3.1 格式。
- 计算逐 token **跨度注释**:`AssistantSpan`, `ToolCallSpan`, `InjectionSpan`, `ActionCommitment`。
- 应用 LMP 策略以生成与 token ID 对齐的二进制损失掩码数组。
- 支持格式系列:`llama_python_tag`, `openai_json`, `anthropic_xml`, `generic_json`。
## 参考文献
- Zou, A., et al. (2024). *Improving Alignment and Robustness with Circuit Breakers.* [arXiv:2406.04313](https://arxiv.org/abs/2406.04313)
- Debenedetti, E., et al. (2024). *AgentDojo: A Dynamic Environment to Evaluate Attacks and Defenses for LLM Agents.* [arXiv:2406.13352](https://arxiv.org/abs/2406.13352)
为 H100 SXM 80GB 上的 Llama-3.1-8B-Instruct 构建。通过 Alliance Canada 上的 SLURM 管理。
标签:AI Agent安全, HuggingFace, IaC 扫描, Llama-3, LLM内生防御, LoRA适配器, RRFA, Trivy, 三元组损失, 凭据扫描, 大语言模型安全, 工具调用安全, 提示词注入防御, 机密管理, 模型微调, 电路断路器, 系统调用监控, 表示重路由, 逆向工具