隐私计算FATE-多分类神经网络算法测试

打印 上一主题 下一主题

主题 859|帖子 859|积分 2577


一、说明

本文分享基于 Fate 使用 横向联邦 神经网络算法 对 多分类 的数据进行 模型训练,并使用该模型对数据进行 多分类预测。


  • 二分类算法:是指待预测的 label 标签的取值只有两种;直白来讲就是每个实例的可能类别只有两种(0 或者 1),例如性别只有 或者 ;此时的分类算法其实是在构建一个分类线将数据划分为两个类别。
  • 多分类算法:是指待预测的 label 标签的取值可能有多种情况,例如个人爱好可能有 篮球足球电影 等等多种类型。常见算法:Softmax、SVM、KNN、决策树。
关于 Fate 的核心概念、单机部署、训练以及预测请参考以下相关文章:
 
二、准备训练数据

上传到 Fate 里的数据有两个字段名必需是规定的,分别是主键为 id 字段和分类字段为 y 字段,y 字段就是所谓的待预测的 label 标签;其他的特征字段(属性)可任意填写,例如下面例子中的 x0 - x9
例如有一条用户数据为: 收入 : 10000,负债 : 5000,是否有还款能力 : 1 ;数据中的 收入 和 负债 就是特征字段,而 是否有还款能力 就是分类字段。
本文只描述关键部分,关于详细的模型训练步骤,请查看文章《隐私计算FATE-模型训练
2.1. guest端

10条数据,包含1个分类字段 y 和 10 个标签字段 x0 - x9

y 值有 0、1、2、3 四个分类
上传到 Fate 中,表名为 muti_breast_homo_guest 命名空间为 experiment
 
2.2. host端

10条数据,字段与 guest 端一样,但是内容不一样

上传到 Fate 中,表名为 muti_breast_homo_host 命名空间为 experiment
 
三、执行训练任务

3.1. 准备dsl文件

创建文件 homo_nn_dsl.json 内容如下 :
  1. {
  2.     "components": {
  3.         "reader_0": {
  4.             "module": "Reader",
  5.             "output": {
  6.                 "data": [
  7.                     "data"
  8.                 ]
  9.             }
  10.         },
  11.         "data_transform_0": {
  12.             "module": "DataTransform",
  13.             "input": {
  14.                 "data": {
  15.                     "data": [
  16.                         "reader_0.data"
  17.                     ]
  18.                 }
  19.             },
  20.             "output": {
  21.                 "data": [
  22.                     "data"
  23.                 ],
  24.                 "model": [
  25.                     "model"
  26.                 ]
  27.             }
  28.         },
  29.         "homo_nn_0": {
  30.             "module": "HomoNN",
  31.             "input": {
  32.                 "data": {
  33.                     "train_data": [
  34.                         "data_transform_0.data"
  35.                     ]
  36.                 }
  37.             },
  38.             "output": {
  39.                 "data": [
  40.                     "data"
  41.                 ],
  42.                 "model": [
  43.                     "model"
  44.                 ]
  45.             }
  46.         }
  47.     }
  48. }
复制代码
 
3.2. 准备conf文件

创建文件 homo_nn_multi_label_conf.json 内容如下 :
  1. {
  2.     "dsl_version": 2,
  3.     "initiator": {
  4.         "role": "guest",
  5.         "party_id": 9999
  6.     },
  7.     "role": {
  8.         "arbiter": [
  9.             10000
  10.         ],
  11.         "host": [
  12.             10000
  13.         ],
  14.         "guest": [
  15.             9999
  16.         ]
  17.     },
  18.     "component_parameters": {
  19.         "common": {
  20.             "data_transform_0": {
  21.                 "with_label": true
  22.             },
  23.             "homo_nn_0": {
  24.                 "encode_label": true,
  25.                 "max_iter": 15,
  26.                 "batch_size": -1,
  27.                 "early_stop": {
  28.                     "early_stop": "diff",
  29.                     "eps": 0.0001
  30.                 },
  31.                 "optimizer": {
  32.                     "learning_rate": 0.05,
  33.                     "decay": 0.0,
  34.                     "beta_1": 0.9,
  35.                     "beta_2": 0.999,
  36.                     "epsilon": 1e-07,
  37.                     "amsgrad": false,
  38.                     "optimizer": "Adam"
  39.                 },
  40.                 "loss": "categorical_crossentropy",
  41.                 "metrics": [
  42.                     "accuracy"
  43.                 ],
  44.                 "nn_define": {
  45.                     "class_name": "Sequential",
  46.                     "config": {
  47.                         "name": "sequential",
  48.                         "layers": [
  49.                             {
  50.                                 "class_name": "Dense",
  51.                                 "config": {
  52.                                     "name": "dense",
  53.                                     "trainable": true,
  54.                                     "batch_input_shape": [
  55.                                         null,
  56.                                         18
  57.                                     ],
  58.                                     "dtype": "float32",
  59.                                     "units": 5,
  60.                                     "activation": "relu",
  61.                                     "use_bias": true,
  62.                                     "kernel_initializer": {
  63.                                         "class_name": "GlorotUniform",
  64.                                         "config": {
  65.                                             "seed": null,
  66.                                             "dtype": "float32"
  67.                                         }
  68.                                     },
  69.                                     "bias_initializer": {
  70.                                         "class_name": "Zeros",
  71.                                         "config": {
  72.                                             "dtype": "float32"
  73.                                         }
  74.                                     },
  75.                                     "kernel_regularizer": null,
  76.                                     "bias_regularizer": null,
  77.                                     "activity_regularizer": null,
  78.                                     "kernel_constraint": null,
  79.                                     "bias_constraint": null
  80.                                 }
  81.                             },
  82.                             {
  83.                                 "class_name": "Dense",
  84.                                 "config": {
  85.                                     "name": "dense_1",
  86.                                     "trainable": true,
  87.                                     "dtype": "float32",
  88.                                     "units": 4,
  89.                                     "activation": "sigmoid",
  90.                                     "use_bias": true,
  91.                                     "kernel_initializer": {
  92.                                         "class_name": "GlorotUniform",
  93.                                         "config": {
  94.                                             "seed": null,
  95.                                             "dtype": "float32"
  96.                                         }
  97.                                     },
  98.                                     "bias_initializer": {
  99.                                         "class_name": "Zeros",
  100.                                         "config": {
  101.                                             "dtype": "float32"
  102.                                         }
  103.                                     },
  104.                                     "kernel_regularizer": null,
  105.                                     "bias_regularizer": null,
  106.                                     "activity_regularizer": null,
  107.                                     "kernel_constraint": null,
  108.                                     "bias_constraint": null
  109.                                 }
  110.                             }
  111.                         ]
  112.                     },
  113.                     "keras_version": "2.2.4-tf",
  114.                     "backend": "tensorflow"
  115.                 },
  116.                 "config_type": "keras"
  117.             }
  118.         },
  119.         "role": {
  120.             "host": {
  121.                 "0": {
  122.                     "reader_0": {
  123.                         "table": {
  124.                             "name": "muti_breast_homo_host",
  125.                             "namespace": "experiment"
  126.                         }
  127.                     }
  128.                 }
  129.             },
  130.             "guest": {
  131.                 "0": {
  132.                     "reader_0": {
  133.                         "table": {
  134.                             "name": "muti_breast_homo_guest",
  135.                             "namespace": "experiment"
  136.                         }
  137.                     }
  138.                 }
  139.             }
  140.         }
  141.     }
  142. }
复制代码
注意 reader_0 组件的表名和命名空间需与上传数据时配置的一致。
 
3.3. 提交任务

执行以下命令:
  1. flow job submit -d homo_nn_dsl.json -c homo_nn_multi_label_conf.json
复制代码
执行成功后,查看 dashboard 显示:

 
四、准备预测数据

与前面训练的数据字段一样,但是内容不一样,y 值全为 0
4.1. guest端


上传到 Fate 中,表名为 predict_muti_breast_homo_guest 命名空间为 experiment
 
4.2. host端


上传到 Fate 中,表名为 predict_muti_breast_homo_host 命名空间为 experiment
 
五、准备预测配置

本文只描述关键部分,关于详细的预测步骤,请查看文章《隐私计算FATE-离线预测
创建文件 homo_nn_multi_label_predict.json 内容如下 :
  1. {
  2.     "dsl_version": 2,
  3.     "initiator": {
  4.         "role": "guest",
  5.         "party_id": 9999
  6.     },
  7.     "role": {
  8.         "arbiter": [
  9.             10000
  10.         ],
  11.         "host": [
  12.             10000
  13.         ],
  14.         "guest": [
  15.             9999
  16.         ]
  17.     },
  18.     "job_parameters": {
  19.         "common": {
  20.             "model_id": "arbiter-10000#guest-9999#host-10000#model",
  21.             "model_version": "202207061504081543620",
  22.             "job_type": "predict"
  23.         }
  24.     },
  25.     "component_parameters": {
  26.         "role": {
  27.             "guest": {
  28.                 "0": {
  29.                     "reader_0": {
  30.                         "table": {
  31.                             "name": "predict_muti_breast_homo_guest",
  32.                             "namespace": "experiment"
  33.                         }
  34.                     }
  35.                 }
  36.             },
  37.             "host": {
  38.                 "0": {
  39.                     "reader_0": {
  40.                         "table": {
  41.                             "name": "predict_muti_breast_homo_host",
  42.                             "namespace": "experiment"
  43.                         }
  44.                     }
  45.                 }
  46.             }
  47.         }
  48.     }
  49. }
复制代码
注意以下两点:

  • model_id 和 model_version 需修改为模型部署后的版本号。
  • reader_0 组件的表名和命名空间需与上传数据时配置的一致。
 
六、执行预测任务

执行以下命令:
  1. flow job submit -c homo_nn_multi_label_predict.json
复制代码
执行成功后,查看 homo_nn_0 组件的数据输出:

可以看到算法输出的预测结果。
 
扫码关注有惊喜!


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

立聪堂德州十三局店

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表