联邦学习开源平台Flower:在移动装备上运行联邦学习

打印 上一主题 下一主题

主题 868|帖子 868|积分 2604

Flower官网:https://flower.ai/
官网doc:https://flower.ai/docs/
Flower github:https://github.com/adap/flower
为什么用Flower

Flower可以将预先写好的会合式机器学习代码以联邦学习的方式运行(只需要少量的修改)。并且他可以在windows情况下模拟联邦学习场景,非常适合实验。
Flower开源了很多联邦学习的基线算法的示例,可以轻松入门。
本文主要将一下如何使用Flower,通过对官网给出的quickstart示例,联合其他csdn,给出运行Flower的步调以及在android上运行联邦学习的示例。
Flower安装

Flower安装需要至少python 3.8版本以上,保举使用python 3.10 及以上,本文将使用python 3.8 运行下列的示例。
官网安装教程
创建自己的假造情况:conda

通过anaconda创建假造情况:
  1. conda create -n flwr python=3.8
复制代码
激活情况
  1. conda activate flwr
复制代码
直接安装(稳定版):
  1. pip install flwr
复制代码
注意:直接pip安装的是flower的稳定版本,与官网的版本可能不同。
查看Flower的安装版本
  1. python -c "import flwr;print(flwr.__version__)"
复制代码
我安装的1.10.0版本。
背面的示例请移至对应版本的doc官网,详细版本的官网在doc官网的左下角的Versions可以找到
Flower V-1.10.0的doc地址
Flower实例运行(quickstart pytorch和quickstart tensorflow)

quickstart pytorch

代码参考csdn
编写代码

创建两个文件:client.py 和 server.py
client.py:编写客户端运行文件,主要包括传统的机器学习流程代码和Flower客户端类实现。详细如下
  1. # 传统的机器学习流程代码
  2. from collections import OrderedDict
  3. import flwr as fl
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torch.utils.data import DataLoader
  8. from torchvision.datasets import CIFAR10
  9. from torchvision.transforms import Compose, Normalize, ToTensor
  10. from tqdm import tqdm
  11. DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  12. # 定义神经网络模型
  13. class Net(nn.Module):
  14.     def __init__(self) -> None:
  15.         super(Net, self).__init__()
  16.         self.conv1 = nn.Conv2d(3, 6, 5)
  17.         self.pool = nn.MaxPool2d(2, 2)
  18.         self.conv2 = nn.Conv2d(6, 16, 5)
  19.         self.fc1 = nn.Linear(16 * 5 * 5, 120)
  20.         self.fc2 = nn.Linear(120, 84)
  21.         self.fc3 = nn.Linear(84, 10)
  22.     def forward(self, x: torch.Tensor) -> torch.Tensor:
  23.         x = self.pool(F.relu(self.conv1(x)))
  24.         x = self.pool(F.relu(self.conv2(x)))
  25.         x = x.view(-1, 16 * 5 * 5)
  26.         x = F.relu(self.fc1(x))
  27.         x = F.relu(self.fc2(x))
  28.         return self.fc3(x)
  29. # 定义模型训练流程
  30. def train(net, trainloader, epochs):
  31.     criterion = torch.nn.CrossEntropyLoss()
  32.     optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  33.     for _ in range(epochs):
  34.         for images, labels in tqdm(trainloader):
  35.             optimizer.zero_grad()
  36.             criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
  37.             optimizer.step()
  38. # 定义模型推理流程
  39. def test(net, testloader):
  40.     criterion = torch.nn.CrossEntropyLoss()
  41.     correct, loss = 0, 0.0
  42.     with torch.no_grad():
  43.         for images, labels in tqdm(testloader):
  44.             outputs = net(images.to(DEVICE))
  45.             labels = labels.to(DEVICE)
  46.             loss += criterion(outputs, labels).item()
  47.             correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
  48.     accuracy = correct / len(testloader.dataset)
  49.     return loss, accuracy
  50. # 定义数据集的获取
  51. def load_data():
  52.     """Load CIFAR-10 (training and test set)."""
  53.     trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  54.     trainset = CIFAR10("./data", train=True, download=True, transform=trf)
  55.     testset = CIFAR10("./data", train=False, download=True, transform=trf)
  56.     return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)
  57. # 生成模型对象,实际获取训练与测试数据集
  58. net = Net().to(DEVICE)
  59. trainloader, testloader = load_data()
复制代码
  1. # 客户端类实现
  2. class FlowerClient(fl.client.NumPyClient):
  3.     # 获取本地模型对应的参数
  4.     def get_parameters(self, config):
  5.         return [val.cpu().numpy() for _, val in net.state_dict().items()]
  6.     # 接收模型参数,并更新本地模型
  7.     def set_parameters(self, parameters):
  8.         params_dict = zip(net.state_dict().keys(), parameters)
  9.         state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
  10.         net.load_state_dict(state_dict, strict=True)
  11.     # 本地模型训练,会先调用 set_parameters() 基于收到的全局模型参数更新本地模型
  12.     def fit(self, parameters, config):
  13.         self.set_parameters(parameters)
  14.         train(net, trainloader, epochs=1)
  15.         return self.get_parameters(config={}), len(trainloader.dataset), {}
  16.     # 基于测试数据集进行测试
  17.     def evaluate(self, parameters, config):
  18.         self.set_parameters(parameters)
  19.         loss, accuracy = test(net, testloader)
  20.         return loss, len(testloader.dataset), {"accuracy": accuracy}
  21. # 启动 Flower 客户端
  22. fl.client.start_numpy_client(
  23.     server_address="127.0.0.1:8080",
  24.     client=FlowerClient(),
  25. )
复制代码
客户端类会继续flwr的NumPyClient类,当服务器选择一个特定的客户端举行训练时,他会通过网络发送训练指令。
这里服务器与客户端在同一个主机上运行,因此server_address就可以用本地回环地址127.0.0.1,FL服务器默认端口使用8080。如果服务器和客户端不是同一台主机,则可以使用真实的IP地址。
server.py:
  1. from typing import List, Tuple
  2. import flwr as fl
  3. from flwr.common import Metrics
  4. # 定义指标聚合方法
  5. def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
  6.     accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
  7.     examples = [num_examples for num_examples, _ in metrics]
  8.     return {"accuracy": sum(accuracies) / sum(examples)}
  9. # 定义模型聚合策略
  10. strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)
  11. # 启动 Flower 服务端
  12. fl.server.start_server(
  13.     server_address="0.0.0.0:8080",                # 服务器地址
  14.     config=fl.server.ServerConfig(num_rounds=3),
  15.     strategy=strategy,
  16. )
复制代码
在服务器代码中界说聚合策略。
运行python文件

先进入文件地点的文件夹,并激活假造情况。
直接运行server.py
  1. python server.py
复制代码
重新打开一个终端运行client.py
  1. python client.py
复制代码
服务器默认最少的客户端数量是2,所以要运行两个client才能举行联邦学习训练。
(重新打开一个终端,再运行一次client.py / 一共三个终端)
运行效果

server

clinet1

client2

quickstart tensorflow

与quickstart pytorch雷同,创建两个python文件:client.py和server.py,代码参考官网给出的代码,但需要轻微调解一下。
client.py:需要注意的是server的地址,可以用上面谁人地址
  1. import flwr as fl
  2. import tensorflow as tf
  3. # 加载数据
  4. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  5. # 加载模型,10分类模型MobilNetV2
  6. model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
  7. model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
  8. # 定义client类,也就是flower客户端
  9. class CifarClient(fl.client.NumPyClient):
  10.     def get_parameters(self, config):
  11.         return model.get_weights()
  12.     def fit(self, parameters, config):
  13.         model.set_weights(parameters)
  14.         model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)
  15.         return model.get_weights(), len(x_train), {}
  16.     def evaluate(self, parameters, config):
  17.         model.set_weights(parameters)
  18.         loss, accuracy = model.evaluate(x_test, y_test)
  19.         return loss, len(x_test), {"accuracy": float(accuracy)}
  20.    
  21. fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client()) # 启动flower客户端
复制代码
server.py:注意server的地址
  1. import flwr as fl
  2. if __name__ == '__main__':
  3.     fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))     # 启动flower服务器
复制代码
如果按照官网里的谁人地址,我运行会报错,显示连接不到地址

运行效果

server

client1

client2

Flower实例运行(quickstart android)

github官网安卓示例
在安卓客户端上使用TFLite举行CIFAR10的联邦学习,将CIFAR-10数据集随机分配给10个客户端,服务器用python运行,客户端运行在安卓上。背景线程是通过安卓的WorkManager库建立的,因此它可以在8到13的安卓版本上运行。
起首需要有安卓假造机,可以下载android studio,本文将不在赘述。
下载源码:从github上下载example里的源码。源码里有示例的apk文件,需要先把这个apk文件下载下来(注册个账号,直接下载到本地)
https://www.dropbox.com/s/ii0vwrjrpupifiv/flower-client.apk?dl=0
源码中需要最少四个android装备才能运行联邦学习,当然,这个可以在server.py文件中更改(我的主机运行不了那么多假造装备,所以就测试2个),详细如下:
  1.                 min_fit_clients=2,        # 根据自己需要更改最少客户端
  2.                 min_evaluate_clients=2,
  3.                 min_available_clients=2,
复制代码
创建两个Android studio假造装备,打开这两个装备,将刚下载的apk文件拖拽到假造机里,假造机会自动下载apk应用。下载完会有一个flower的应用,点开如下图所示:

在假造机app里输入client id,server IP / port
server.py代码里的id是0.0.0.0:8080,这里就输入真实的ip就行了,Port就是8080

激活假造情况并下载依赖项:
  1. pip install -r requirements.txt
复制代码
运行服务器:
  1. python server.py
复制代码
依次点击假造装备app里的三个黄色按钮
运行效果:
server

client1

client2

我目前只能运行apk已有的示例,examples里的android项目没有搭建成功
在Android studio中构建案例中的android的情况时,报错了:无法找到依赖项TFLite;位置在app文件夹下的build.gradle里


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

天津储鑫盛钢材现货供应商

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表