theSiddhantPatel/cti-ml-lab
GitHub: theSiddhantPatel/cti-ml-lab
基于Flower和PyTorch构建的联邦学习实验平台,使多个客户端能够在不共享原始数据的情况下协同训练全局模型。
Stars: 1 | Forks: 0
# 使用 Flower 和 PyTorch 构建联邦学习系统
## 项目概述
本项目展示了如何使用 **Flower** 和 **PyTorch** 构建一个简单的联邦学习系统。多个客户端在各自的本地 MNIST 数据上训练同一个模型,同时中央服务器收集并聚合模型更新,以创建更好的全局模型。
其核心思想是**在不将原始客户端数据发送到服务器的情况下**训练共享模型。这使该项目成为隐私感知分布式机器学习的一个适合初学者的优秀示例。
## 什么是联邦学习?
联邦学习是一种机器学习方法,许多客户端在本地根据自己的数据训练模型。
客户端不共享数据集,而是仅将模型更新发送给服务器,服务器将这些更新组合起来以改进全局模型。
## 功能特性
- 多客户端联邦学习设置
- 用于服务器端聚合的自定义 Flower 策略
- 跨训练轮次的全局模型聚合
- 每轮后的准确率追踪
- 使用 `matplotlib` 进行准确率可视化
- 每轮后保存全局模型
- 基于 MNIST 的示例,易于理解和运行
## 技术栈
- Python
- Flower
- PyTorch
- Torchvision
- NumPy
- Matplotlib
## 项目结构
```
cti-ml-lab/
├── federated_cti/
│ ├── main.py # Client code for local training and evaluation
│ ├── server.py # Server setup, custom strategy, and aggregation logic
│ ├── run_clients.py # Starts multiple clients automatically
│ └── plot_metrics.py # Plots accuracy across training rounds
├── requirements.txt # Project dependencies
└── README.md # Project documentation
```
## 项目工作原理
1. 服务器启动并等待客户端连接。
2. 每个客户端加载自己的 MNIST 训练数据部分。
3. 客户端使用 PyTorch 在本地训练模型。
4. 服务器从所有客户端收集更新后的模型参数。
5. 服务器将这些更新聚合为一个全局模型。
6. 此过程重复多轮。
7. 每轮结束后,保存全局模型并追踪准确率。
## 安装说明
### 1. 克隆仓库
```
git clone
cd cti-ml-lab
```
### 2. 创建虚拟环境
```
python -m venv venv
```
### 3. 激活虚拟环境
在 Windows 上:
```
venv\Scripts\activate
```
在 macOS/Linux 上:
```
source venv/bin/activate
```
### 4. 安装依赖
```
pip install -r requirements.txt
```
## 如何运行项目
打开终端并进入项目文件夹:
```
cd federated_cti
```
### 步骤 1:启动服务器
```
python server.py
```
服务器将在以下地址启动:
```
127.0.0.1:8081
```
### 步骤 2:启动客户端
在同一文件夹中打开另一个终端并运行:
```
python run_clients.py
```
此脚本会启动 3 个客户端。每个客户端在 MNIST 数据集的不同部分上进行训练。
### 步骤 3:绘制准确率图表
训练完成后,运行:
```
python plot_metrics.py
```
这将根据保存的训练历史生成准确率图表。
## 示例输出
在训练期间,您可能会看到类似以下的输出:
```
Round 1 Global Accuracy: 0.7821
Round 2 Global Accuracy: 0.8467
Round 3 Global Accuracy: 0.8794
```
这表明随着服务器整合来自所有客户端的学习成果,全局模型随时间推移而改进。
## 训练后生成的文件
- `accuracy_history.json` - 存储每轮的准确率数值
- `global_model_round_1.pth` - 第 1 轮后保存的全局模型
- `global_model_round_2.pth` - 第 2 轮后保存的全局模型
- `global_model_round_3.pth` - 第 3 轮后保存的全局模型
- `accuracy_plot.png` - 由 `plot_metrics.py` 创建的准确率图表
## 未来改进方向
- 增加对更多客户端和灵活客户端选择的支持
- 使用更深的神经网络以获得更好的性能
- 增加对非独立同分布(non-IID)数据的支持
- 同时追踪损失值和准确率
- 以更详细的格式保存训练日志
- 添加配置文件以便于实验设置
- 支持 GPU 训练以加快本地更新速度
## 为什么这个项目有用
该项目是学习联邦学习实际运作方式的简单起点。它有助于初学者理解本地训练、服务器聚合和基于轮次的学习如何在实际实现中协同工作。
## 总结
该联邦学习系统演示了如何将 Flower 和 PyTorch 结合使用,在保持客户端数据在本地的同时训练共享模型。它规模小、实用,且易于扩展以进行更高级的联邦学习实验。
标签:Apex, Flower, MNIST, Mutation, Python, PyTorch, Torchvision, 人工智能, 凭据扫描, 分布式训练, 可视化, 后端开发, 威胁情报, 开发者工具, 异常检测, 数据隐私, 数据预处理, 无后门, 机器学习, 模型聚合, 深度学习, 用户模式Hook绕过, 网络安全, 网络安全, 联邦学习, 逆向工具, 隐私保护, 隐私保护, 隐私计算