ToB企服应用市场:ToB评测及商务社交产业平台

标题: matplotlib 动态显示训练过程中的数据和模子的决策边界 [打印本页]

作者: 吴旭华    时间: 2024-6-13 21:18
标题: matplotlib 动态显示训练过程中的数据和模子的决策边界
Github


官网


文档


简介

matplotlib 是 Python 中最常用的绘图库之一,用于创建各种类型的静态、动态和交互式可视化。
动态显示训练过程中的数据和模子的决策边界


安装

  1. pip install tensorflow==2.13.1
  2. pip install matplotlib==3.7.5
  3. pip install numpy==1.24.3
复制代码
源码

  1. import numpy as np
  2. import tensorflow as tf
  3. from tensorflow.keras.models import Sequential
  4. from tensorflow.keras.layers import Dense
  5. import matplotlib.pyplot as plt
  6. from matplotlib.colors import ListedColormap
  7. # 生成数据
  8. np.random.seed(0)
  9. num_samples_per_class = 500
  10. negative_samples = np.random.multivariate_normal(
  11.     mean=[0, 3],
  12.     cov=[[1, 0.5], [0.5, 1]],
  13.     size=num_samples_per_class
  14. )
  15. positive_samples = np.random.multivariate_normal(
  16.     mean=[3, 0],
  17.     cov=[[1, 0.5], [0.5, 1]],
  18.     size=num_samples_per_class
  19. )
  20. inputs = np.vstack((negative_samples, positive_samples)).astype(np.float32)
  21. targets = np.vstack((np.zeros((num_samples_per_class, 1)), np.ones((num_samples_per_class, 1)))).astype(np.float32)
  22. # 将数据分为训练集和测试集
  23. train_size = int(0.8 * len(inputs))
  24. X_train, X_test = inputs[:train_size], inputs[train_size:]
  25. y_train, y_test = targets[:train_size], targets[train_size:]
  26. # 构建二分类模型
  27. model = Sequential([
  28.     # 输入层:输入形状为 (2,)
  29.     # 第一个隐藏层:包含 4 个节点,激活函数使用 ReLU
  30.     Dense(4, activation='relu', input_shape=(2,)),
  31.    
  32.     # 输出层:包含 1 个节点,激活函数使用 Sigmoid(因为是二分类问题)
  33.     Dense(1, activation='sigmoid')
  34. ])
  35. # 编译模型
  36. # 指定优化器为 Adam,损失函数为二分类交叉熵,评估指标为准确率
  37. model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  38. # 准备绘图
  39. fig, ax = plt.subplots()
  40. cmap_light = ListedColormap(['#FFAAAA', '#AAAAFF'])
  41. cmap_bold = ListedColormap(['#FF0000', '#0000FF'])
  42. # 动态绘制函数
  43. def plot_decision_boundary(epoch, logs):
  44.     ax.clear()
  45.     x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
  46.     y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
  47.     xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
  48.                          np.arange(y_min, y_max, 0.1))
  49.     grid = np.c_[xx.ravel(), yy.ravel()]
  50.     probs = model.predict(grid).reshape(xx.shape)
  51.     ax.contourf(xx, yy, probs, alpha=0.8, cmap=cmap_light)
  52.     ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train[:, 0], edgecolor='k', cmap=cmap_bold)
  53.     ax.set_title(f'Epoch {epoch+1}')
  54.     plt.draw()
  55.     plt.pause(0.01)
  56. # 自定义回调函数
  57. class PlotCallback(tf.keras.callbacks.Callback):
  58.     def on_epoch_end(self, epoch, logs=None):
  59.         plot_decision_boundary(epoch, logs)
  60. # 训练模型并动态显示
  61. plot_callback = PlotCallback()
  62. model.fit(X_train, y_train, epochs=50, batch_size=16, callbacks=[plot_callback])
  63. # 评估模型
  64. loss, accuracy = model.evaluate(X_test, y_test)
  65. print(f"Test Loss: {loss}")
  66. print(f"Test Accuracy: {accuracy}")
  67. plt.show()
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4