1. Code Example
```python
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Define a multi-input, multi-output model
vocabulary_size = 1000
num_tags = 100
num_departments = 4

title = keras.Input(shape=(vocabulary_size,), name="title")
text_body = keras.Input(shape=(vocabulary_size,), name="text_body")
tags = keras.Input(shape=(num_tags,), name="tags")

# Merge all input features into a single tensor, then share one hidden layer
features = layers.Concatenate()([title, text_body, tags])
features = layers.Dense(64, activation="relu")(features)

# Two output heads: a priority score in [0, 1] and a department classifier
priority = layers.Dense(1, activation="sigmoid", name="priority")(features)
department = layers.Dense(
    num_departments, activation="softmax", name="department")(features)

model = keras.Model(inputs=[title, text_body, tags],
                    outputs=[priority, department])

# Train the multi-input, multi-output model on random dummy data
num_samples = 1280
title_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))
priority_data = np.random.random(size=(num_samples, 1))
# One-hot department labels (exactly one department per sample),
# which is what softmax + categorical_crossentropy expects
department_data = keras.utils.to_categorical(
    np.random.randint(0, num_departments, size=num_samples), num_departments)

model.compile(optimizer="rmsprop",
              loss=["mean_squared_error", "categorical_crossentropy"],
              metrics=[["mean_absolute_error"], ["accuracy"]])
model.fit([title_data, text_body_data, tags_data],
          [priority_data, department_data],
          epochs=10)
model.evaluate([title_data, text_body_data, tags_data],
               [priority_data, department_data])
priority_preds, department_preds = model.predict(
    [title_data, text_body_data, tags_data])
```
Output (partial log; the exact numbers vary between runs because the training data is random):
```text
Epoch 1/10
40/40 [==============================] - 1s 2ms/step - loss: 4.5477 - priority_loss: 0.1296 - department_loss: 4.4181 - priority_mean_absolute_error: 0.2958 - department_accuracy: 0.2766
Epoch 2/10
40/40 [==============================] - 0s 2ms/step - loss: 4.1786 - priority_loss: 0.1377 - department_loss: 4.0410 - priority_mean_absolute_error: 0.3057 - department_accuracy: 0.3273
Epoch 3/10
40/40 [==============================] - 0s 2ms/step - loss: 4.8698 - priority_loss: 0.1714 - department_loss: 4.6984 - priority_mean_absolute_error: 0.3389 - department_accuracy: 0.3023
Epoch 4/10
40/40 [==============================] - 0s 2ms/step - loss: 5.5446 - priority_loss: 0.2163 - department_loss: 5.3283 - priority_mean_absolute_error: 0.3830 - department_accuracy: 0.3195
Epoch 5/10
40/40 [==============================] - 0s 2ms/step - loss: 7.1691 - priority_loss: 0.2945 - department_loss: 6.8746 - priority_mean_absolute_error: 0.4610 - department_accuracy: 0.3102
Epoch 6/10
40/40 [==============================] - 0s 2ms/step - loss: 7.9824 - priority_loss: 0.3229 - department_loss: 7.6595 - priority_mean_absolute_error: 0.4873 - department_accuracy: 0.2773
Epoch 7/10
40/40 [==============================] - 0s 2ms/step - loss: 9.4634 - priority_loss:
```
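Two numbers in this log can be checked by hand. `40/40` is the number of batches per epoch: `model.fit` uses `batch_size=32` by default, and 1280 / 32 = 40. The `loss` column is the sum of `priority_loss` and `department_loss`, because `compile` was given no `loss_weights` (so each output loss counts once) and the model has no regularization losses. The minimal sketch below recomputes both from values copied out of the log; the helper names `steps_per_epoch` and `total_loss` are hypothetical, chosen only for illustration.

```python
import math

# (loss, priority_loss, department_loss) copied from epochs 1-6 of the log
log = [
    (4.5477, 0.1296, 4.4181),
    (4.1786, 0.1377, 4.0410),
    (4.8698, 0.1714, 4.6984),
    (5.5446, 0.2163, 5.3283),
    (7.1691, 0.2945, 6.8746),
    (7.9824, 0.3229, 7.6595),
]


def steps_per_epoch(num_samples, batch_size=32):
    """Batches per epoch; the last batch may be partial, hence ceil."""
    return math.ceil(num_samples / batch_size)


def total_loss(losses, weights=None):
    """Weighted sum of per-output losses (every weight is 1 if none given)."""
    if weights is None:
        weights = [1.0] * len(losses)
    return sum(w * l for w, l in zip(weights, losses))


print(steps_per_epoch(1280))  # 40, matching "40/40" in the log

for loss, priority_loss, department_loss in log:
    # The log rounds to 4 decimals, so allow a small tolerance
    assert abs(total_loss([priority_loss, department_loss]) - loss) < 2e-4
```

If `loss_weights=[w1, w2]` were passed to `compile`, the objective Keras minimizes would become the weighted sum `w1 * priority_loss + w2 * department_loss`, which is what `total_loss` computes when `weights` is given.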
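Because every `Input` and output layer in this model has a `name`, Keras also accepts inputs, targets, losses, and metrics as dicts keyed by those names instead of ordered lists, which removes any dependence on list order. The sketch below rebuilds the same architecture and trains it that way for one epoch. The `loss_weights` values are an arbitrary illustration (an assumption, not taken from the example above), and `keras.utils.to_categorical` is used to produce one-hot department labels.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

vocabulary_size = 1000
num_tags = 100
num_departments = 4
num_samples = 1280

# Same architecture; the layer names are what the dict keys refer to
title = keras.Input(shape=(vocabulary_size,), name="title")
text_body = keras.Input(shape=(vocabulary_size,), name="text_body")
tags = keras.Input(shape=(num_tags,), name="tags")
features = layers.Concatenate()([title, text_body, tags])
features = layers.Dense(64, activation="relu")(features)
priority = layers.Dense(1, activation="sigmoid", name="priority")(features)
department = layers.Dense(num_departments, activation="softmax",
                          name="department")(features)
model = keras.Model(inputs=[title, text_body, tags],
                    outputs=[priority, department])

# Random dummy data, keyed by input and output layer names
inputs = {
    "title": np.random.randint(0, 2, size=(num_samples, vocabulary_size)),
    "text_body": np.random.randint(0, 2, size=(num_samples, vocabulary_size)),
    "tags": np.random.randint(0, 2, size=(num_samples, num_tags)),
}
targets = {
    "priority": np.random.random(size=(num_samples, 1)),
    "department": keras.utils.to_categorical(
        np.random.randint(0, num_departments, size=num_samples),
        num_departments),
}

model.compile(
    optimizer="rmsprop",
    loss={"priority": "mean_squared_error",
          "department": "categorical_crossentropy"},
    metrics={"priority": ["mean_absolute_error"],
             "department": ["accuracy"]},
    # Illustrative weights only, not tuned values
    loss_weights={"priority": 1.0, "department": 0.5},
)
model.fit(inputs, targets, epochs=1, verbose=0)

# predict returns one array per output, in the order given to keras.Model
priority_preds, department_preds = model.predict(inputs, verbose=0)
print(priority_preds.shape, department_preds.shape)  # (1280, 1) (1280, 4)
```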