matplotlib 动态显示训练过程中的数据和模子的决策边界 ...

打印 上一主题 下一主题

主题 564|帖子 564|积分 1692

Github



  • https://github.com/matplotlib/matplotlib
官网



  • https://matplotlib.org/stable/
文档



  • https://matplotlib.org/stable/api/index.html
简介

matplotlib 是 Python 中最常用的绘图库之一,用于创建各种类型的静态、动态和交互式可视化。
动态显示训练过程中的数据和模子的决策边界


安装

  1. pip install tensorflow==2.13.1
  2. pip install matplotlib==3.7.5
  3. pip install numpy==1.24.3
复制代码
源码

  1. import numpy as np
  2. import tensorflow as tf
  3. from tensorflow.keras.models import Sequential
  4. from tensorflow.keras.layers import Dense
  5. import matplotlib.pyplot as plt
  6. from matplotlib.colors import ListedColormap
  7. # 生成数据
  8. np.random.seed(0)
  9. num_samples_per_class = 500
  10. negative_samples = np.random.multivariate_normal(
  11.     mean=[0, 3],
  12.     cov=[[1, 0.5], [0.5, 1]],
  13.     size=num_samples_per_class
  14. )
  15. positive_samples = np.random.multivariate_normal(
  16.     mean=[3, 0],
  17.     cov=[[1, 0.5], [0.5, 1]],
  18.     size=num_samples_per_class
  19. )
  20. inputs = np.vstack((negative_samples, positive_samples)).astype(np.float32)
  21. targets = np.vstack((np.zeros((num_samples_per_class, 1)), np.ones((num_samples_per_class, 1)))).astype(np.float32)
  22. # 将数据分为训练集和测试集
  23. train_size = int(0.8 * len(inputs))
  24. X_train, X_test = inputs[:train_size], inputs[train_size:]
  25. y_train, y_test = targets[:train_size], targets[train_size:]
  26. # 构建二分类模型
  27. model = Sequential([
  28.     # 输入层:输入形状为 (2,)
  29.     # 第一个隐藏层:包含 4 个节点,激活函数使用 ReLU
  30.     Dense(4, activation='relu', input_shape=(2,)),
  31.    
  32.     # 输出层:包含 1 个节点,激活函数使用 Sigmoid(因为是二分类问题)
  33.     Dense(1, activation='sigmoid')
  34. ])
  35. # 编译模型
  36. # 指定优化器为 Adam,损失函数为二分类交叉熵,评估指标为准确率
  37. model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  38. # 准备绘图
  39. fig, ax = plt.subplots()
  40. cmap_light = ListedColormap(['#FFAAAA', '#AAAAFF'])
  41. cmap_bold = ListedColormap(['#FF0000', '#0000FF'])
  42. # 动态绘制函数
  43. def plot_decision_boundary(epoch, logs):
  44.     ax.clear()
  45.     x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
  46.     y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
  47.     xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
  48.                          np.arange(y_min, y_max, 0.1))
  49.     grid = np.c_[xx.ravel(), yy.ravel()]
  50.     probs = model.predict(grid).reshape(xx.shape)
  51.     ax.contourf(xx, yy, probs, alpha=0.8, cmap=cmap_light)
  52.     ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train[:, 0], edgecolor='k', cmap=cmap_bold)
  53.     ax.set_title(f'Epoch {epoch+1}')
  54.     plt.draw()
  55.     plt.pause(0.01)
  56. # 自定义回调函数
  57. class PlotCallback(tf.keras.callbacks.Callback):
  58.     def on_epoch_end(self, epoch, logs=None):
  59.         plot_decision_boundary(epoch, logs)
  60. # 训练模型并动态显示
  61. plot_callback = PlotCallback()
  62. model.fit(X_train, y_train, epochs=50, batch_size=16, callbacks=[plot_callback])
  63. # 评估模型
  64. loss, accuracy = model.evaluate(X_test, y_test)
  65. print(f"Test Loss: {loss}")
  66. print(f"Test Accuracy: {accuracy}")
  67. plt.show()
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

吴旭华

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表