Deep Learning Example 2: A Multi-Input, Multi-Output Neural Network Model

Posted on 2026-1-26 14:43:26

  I. Code Example

from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Define a model with multiple inputs and multiple outputs
vocabulary_size = 1000
num_tags = 100
num_departments = 4

# Three inputs: title, text body, and tags
title = keras.Input(shape=(vocabulary_size,), name="title")
text_body = keras.Input(shape=(vocabulary_size,), name="text_body")
tags = keras.Input(shape=(num_tags,), name="tags")

# Concatenate the inputs into one feature vector, then apply a shared hidden layer
features = layers.Concatenate()([title, text_body, tags])
features = layers.Dense(64, activation="relu")(features)

# Two output heads: a priority score in (0, 1) and a distribution over departments
priority = layers.Dense(1, activation="sigmoid", name="priority")(features)
department = layers.Dense(num_departments, activation="softmax", name="department")(features)
model = keras.Model(inputs=[title, text_body, tags], outputs=[priority, department])

# Train the multi-input, multi-output model on random placeholder data
num_samples = 1280
title_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))
priority_data = np.random.random(size=(num_samples, 1))
# Note: these targets are random 0/1 vectors, not one-hot labels
department_data = np.random.randint(0, 2, size=(num_samples, num_departments))

# One loss and one metric list per output, in the same order as `outputs`
model.compile(optimizer="rmsprop", loss=["mean_squared_error", "categorical_crossentropy"], metrics=[["mean_absolute_error"], ["accuracy"]])
model.fit([title_data, text_body_data, tags_data], [priority_data, department_data], epochs=10)
model.evaluate([title_data, text_body_data, tags_data], [priority_data, department_data])
priority_preds, department_preds = model.predict(
    [title_data, text_body_data, tags_data]
)
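A note on the department targets: `categorical_crossentropy` is normally paired with one-hot targets, whereas `np.random.randint(0, 2, ...)` in the listing produces rows containing anywhere from zero to four ones. The following is only a minimal sketch of one-hot placeholder targets, assuming each sample belongs to exactly one department. The helper name `make_one_hot_departments` and the fixed seed are illustrative and are not part of the original post.

```python
import numpy as np

num_samples = 1280
num_departments = 4

def make_one_hot_departments(n, k, seed=0):
    """Hypothetical helper: draw n random class ids in [0, k) and one-hot encode them."""
    rng = np.random.default_rng(seed)
    class_ids = rng.integers(0, k, size=n)
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(k, dtype="float32")[class_ids]

department_data = make_one_hot_departments(num_samples, num_departments)
print(department_data.shape)  # (1280, 4)
```

Passing this array as the department target to `model.fit` would leave the rest of the listing unchanged. The run log in this post was produced with the original random targets, so its numbers would not carry over.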
Output:
Epoch 1/10
40/40 [==============================] - 1s 2ms/step - loss: 4.5477 - priority_loss: 0.1296 - department_loss: 4.4181 - priority_mean_absolute_error: 0.2958 - department_accuracy: 0.2766
Epoch 2/10
40/40 [==============================] - 0s 2ms/step - loss: 4.1786 - priority_loss: 0.1377 - department_loss: 4.0410 - priority_mean_absolute_error: 0.3057 - department_accuracy: 0.3273
Epoch 3/10
40/40 [==============================] - 0s 2ms/step - loss: 4.8698 - priority_loss: 0.1714 - department_loss: 4.6984 - priority_mean_absolute_error: 0.3389 - department_accuracy: 0.3023
Epoch 4/10
40/40 [==============================] - 0s 2ms/step - loss: 5.5446 - priority_loss: 0.2163 - department_loss: 5.3283 - priority_mean_absolute_error: 0.3830 - department_accuracy: 0.3195
Epoch 5/10
40/40 [==============================] - 0s 2ms/step - loss: 7.1691 - priority_loss: 0.2945 - department_loss: 6.8746 - priority_mean_absolute_error: 0.4610 - department_accuracy: 0.3102
Epoch 6/10
40/40 [==============================] - 0s 2ms/step - loss: 7.9824 - priority_loss: 0.3229 - department_loss: 7.6595 - priority_mean_absolute_error: 0.4873 - department_accuracy: 0.2773
Epoch 7/10
40/40 [==============================] - 0s 2ms/step - loss: 9.4634 - priority_loss:
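When reading this log, it helps to know that with no `loss_weights` passed to `compile`, Keras reports the total `loss` as the sum of the per-output losses. The short check below is a rough sketch that copies the values from epochs 1 to 6 of the log (epoch 7 is cut off and is left out) and confirms the sum holds up to rounding in the last printed digit. The function name `max_sum_gap` is illustrative.

```python
# (priority_loss, department_loss, loss) copied from epochs 1-6 of the log
epochs = [
    (0.1296, 4.4181, 4.5477),
    (0.1377, 4.0410, 4.1786),
    (0.1714, 4.6984, 4.8698),
    (0.2163, 5.3283, 5.5446),
    (0.2945, 6.8746, 7.1691),
    (0.3229, 7.6595, 7.9824),
]

def max_sum_gap(rows):
    """Largest |priority_loss + department_loss - loss| over the given epochs."""
    return max(abs(p + d - total) for p, d, total in rows)

gap = max_sum_gap(epochs)
print(f"largest gap between summed and reported loss: {gap:.4f}")
```

Because the department loss is roughly thirty times larger than the priority loss here, it dominates the total. If the two tasks should contribute more evenly, the `loss_weights` argument of `model.compile` is the usual way to rescale them.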