GNN入门与实践——基于GraphSAGE在Cora数据集上的节点分类研究 ...

打印 上一主题 下一主题

主题 868|帖子 868|积分 2604

Hi,大家好,我是半亩花海。本文介绍了图神经网络(GNN)中的一种紧张算法——GraphSAGE,其通过采样邻居节点聚合信息,能够高效地处理大规模图数据,并通过一个完整的代码示例(包罗数据预处理、模子定义、训练过程、验证与测试以及结果可视化)展示了如安在 Cora 数据集上实现节点分类任务。
  目次
一、为什么我们必要图神经网络?
二、什么是 GraphSAGE?
(一)概念
(二)焦点思想
(三)数学公式
三、基于Cora数据集的GraphSAGE实现
(一)研究过程
(二)结果分析
四、GraphSAGE的优势与未来预测

一、为什么我们必要图神经网络?

近年来,随着深度学习的快速发展,神经网络在图像、文本和语音等领域取得了显著的乐成。然而,这些传统方法重要实用于欧几里得数据(如图像和序列),而许多实际世界中的数据本质上是图布局的,例如社交网络、分子布局、知识图谱等。传统的神经网络难以直接处理这种非欧几里得数据。
图神经网络(Graph Neural Network, GNN) 的出现为解决这一标题提供了新的思路。它通过建模节点之间的关系,能够有效地捕捉图布局中的复杂模式。GNN 已经在推荐系统、药物发现、交通预测等领域显现出巨大的潜力。
本文将通过一个详细的 GraphSAGE 示例,深入探究 GNN 的基本原理、实现细节以及其在实际任务中的应用。

二、什么是 GraphSAGE?

(一)概念

GraphSAGE(Graph Sample and Aggregation)是一种基于采样的图神经网络算法。与传统的图卷积网络(GCN)差别,GraphSAGE 不依靠于整个图的毗邻矩阵举行计算,而是通过对邻居节点举行采样和聚合来天生节点表示。这种方法使得 GraphSAGE 更加高效且可扩展,尤其实用于大规模图数据
(二)焦点思想



  • 采样(Sampling) :为了减少计算开销,GraphSAGE 对每个节点的邻居举行随机采样,而不是使用所有邻居。
  • 聚合(Aggregation) :通过聚合采样邻居的信息,更新目标节点的特征表示。常见的聚合方式包罗均值聚合(mean)、最大池化(max-pooling)等。
  • 逐层流传(Layer-wise Propagation) :每一层都会根据前一层的节点表示和邻居信息天生新的节点表示。
(三)数学公式

假设我们有一个图
,此中
是节点集合,
 是边集合。对于第
层,目标节点
的表示
​ 可以通过以下公式计算:

此中:


  •  表示节点
    的邻居集合;
  •  是聚合函数,例如均值聚合;
  • 是可学习的权重矩阵;
  • 是激活函数,例如

三、基于Cora数据集的GraphSAGE实现

下面我们将通过一个完整的代码示例,展示如何使用GraphSAGE在Cora数据集上举行节点分类任务。
   数据集及源代码链接:PyG-GraphSAGE(直接Download下来就行,好像有一处没加右括号,改正后直接运行main.py即可复现)。
  (一)研究过程

1. 数据预处理

首先,我们加载 Cora 数据集并对其举行归一化处理:
  1. import torch
  2. import numpy as np
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import matplotlib.pyplot as plt
  6. from net import GraphSage
  7. from data import CoraData
  8. from data import CiteseerData
  9. from data import PubmedData
  10. from sampling import multihop_sampling
  11. from collections import namedtuple
  12. # 数据集选择
  13. dataset = "cora"
  14. assert dataset in ["cora", "citeseer", "pubmed"]
  15. # 层数选择
  16. num_layers = 2
  17. assert num_layers in [2, 3]
  18. # 设置输入维度、隐藏层维度和邻居采样数量
  19. if dataset == "cora":
  20.     INPUT_DIM = 1433  # 输入维度
  21.     if num_layers == 2:
  22.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  23.         HIDDEN_DIM = [256, 7]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  24.         NUM_NEIGHBORS_LIST = [10, 10]  # 每阶采样邻居的节点数
  25.     else:
  26.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  27.         HIDDEN_DIM = [256, 128, 7]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  28.         NUM_NEIGHBORS_LIST = [10, 5, 5]  # 每阶采样邻居的节点数
  29. elif dataset == "citeseer":
  30.     INPUT_DIM = 3703  # 输入维度
  31.     if num_layers == 2:
  32.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  33.         HIDDEN_DIM = [256, 6]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  34.         NUM_NEIGHBORS_LIST = [10, 10]  # 每阶采样邻居的节点数
  35.     else:
  36.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  37.         HIDDEN_DIM = [256, 128, 6]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  38.         NUM_NEIGHBORS_LIST = [10, 5, 5]  # 每阶采样邻居的节点数
  39. else:
  40.     INPUT_DIM = 500  # 输入维度
  41.     if num_layers == 2:
  42.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  43.         HIDDEN_DIM = [256, 3]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  44.         NUM_NEIGHBORS_LIST = [10, 10]  # 每阶采样邻居的节点数
  45.     else:
  46.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  47.         HIDDEN_DIM = [256, 128, 3]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  48.         NUM_NEIGHBORS_LIST = [10, 5, 5]  # 每阶采样邻居的节点数
  49. # 定义超参数
  50. BATCH_SIZE = 16  # 批处理大小
  51. EPOCHS = 10  # 训练轮数
  52. NUM_BATCH_PER_EPOCH = 20  # 每个epoch循环的批次数
  53. if dataset == "citeseer":
  54.     LEARNING_RATE = 0.1  # 学习率
  55. else:
  56.     LEARNING_RATE = 0.01
  57. DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  58. # 数据结构定义
  59. Data = namedtuple('Data', ['x', 'y', 'adjacency_dict', 'train_mask', 'val_mask', 'test_mask'])
  60. # 载入数据
  61. if dataset == "cora":
  62.     data = CoraData().data
  63. elif dataset == "citeseer":
  64.     data = CiteseerData().data
  65. else:
  66.     data = PubmedData().data
  67. # 数据归一化
  68. if dataset == "citeseer":
  69.     x = data.x
  70. else:
  71.     x = data.x / data.x.sum(1, keepdims=True)  # 归一化数据,使得每一行和为1
复制代码
说明:


  • INPUT_DIM 是节点特征的维度;
  • HIDDEN_DIM 是隐藏层的维度列表;
  • NUM_NEIGHBORS_LIST 是每层采样的邻居数量;
  • BATCH_SIZE 是每次训练时使用的样本数量;
  • EPOCHS 是总的训练轮数;
  • NUM_BATCH_PER_EPOCH 是每个 epoch 中的批次数量;
  • LEARNING_RATE 是学习率;
  • DEVICE 是使用的装备(CPU 或 GPU)。
2. 定义训练、验证、测试集

接下来,我们将数据集分别为训练集、验证集和测试集:
  1. # 定义训练、验证、测试集
  2. train_index = np.where(data.train_mask)[0]
  3. train_label = data.y
  4. val_index = np.where(data.val_mask)[0]
  5. test_index = np.where(data.test_mask)[0]
复制代码
说明:


  • train_index 是训练集的索引;
  • train_label 是训练集的标签;
  • val_index 是验证集的索引;
  • test_index 是测试集的索引。
3. 实例化模子

我们实例化一个 GraphSAGE 模子,并指定输入维度、隐藏层维度和邻居采样数量:
  1. # 实例化模型
  2. model = GraphSage(
  3.     input_dim=INPUT_DIM,
  4.     hidden_dim=HIDDEN_DIM,
  5.     num_neighbors_list=NUM_NEIGHBORS_LIST,
  6.     aggr_neighbor_method="mean",
  7.     aggr_hidden_method="sum"
  8. ).to(DEVICE)
  9. print(model)
复制代码
说明:


  • input_dim 是节点特征的维度;
  • hidden_dim 是隐藏层的维度列表;
  • num_neighbors_list 是每层采样的邻居数量;
  • aggr_neighbor_method 是邻居聚合的方式(例如均值聚合);
  • aggr_hidden_method 是隐藏层聚合的方式(例如求和)。
4. 定义损失函数和优化器

我们使用交织熵损失函数和 Adam 优化器来训练模子:
  1. # 定义损失函数和优化器
  2. criterion = nn.CrossEntropyLoss().to(DEVICE)
  3. optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)
复制代码
说明:


  • criterion 是交织熵损失函数;
  • optimizer 是 Adam 优化器,带有权重衰减(L2 正则化)。
5. 定义训练函数

训练过程分为以下几个步骤:
(1)采样邻居:对每个批次的节点举行多跳采样,获取其邻居节点的特征。
(2)前向流传:将采样得到的节点特征送入模子,计算节点表示。
(3)损失计算:使用交织熵损失函数计算损失,并通过反向流传更新模子参数。
  1. # 定义训练函数
  2. def train():
  3.     train_losses = []
  4.     train_acces = []
  5.     val_losses = []
  6.     val_acces = []
  7.     model.train()  # 训练模式
  8.     for e in range(EPOCHS):
  9.         train_loss = 0
  10.         train_acc = 0
  11.         val_loss = 0
  12.         val_acc = 0
  13.         if e % 5 == 0:
  14.             optimizer.param_groups[0]['lr'] *= 0.1  # 学习率衰减
  15.         for batch in range(NUM_BATCH_PER_EPOCH):  # 每个epoch循环的批次数
  16.             # 随机从训练集中抽取batch_size个节点(batch_size,num_train_node)
  17.             batch_src_index = np.random.choice(train_index, size=(BATCH_SIZE,))
  18.             # 根据训练节点提取其标签(batch_size,num_train_node)
  19.             batch_src_label = torch.from_numpy(train_label[batch_src_index]).long().to(DEVICE)
  20.             # 进行多跳采样(num_layers+1,num_node)
  21.             batch_sampling_result = multihop_sampling(batch_src_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
  22.             # 根据采样的节点id构造采样节点特征(num_layers+1,num_node,input_dim)
  23.             batch_sampling_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in batch_sampling_result]
  24.             # 送入模型开始训练
  25.             batch_train_logits = model(batch_sampling_x)
  26.             # 计算损失
  27.             loss = criterion(batch_train_logits, batch_src_label)
  28.             train_loss += loss.item()
  29.             # 更新参数
  30.             optimizer.zero_grad()
  31.             loss.backward()  # 反向传播计算参数的梯度
  32.             optimizer.step()  # 使用优化方法进行梯度更新
  33.             # 计算训练精度
  34.             _, pred = torch.max(batch_train_logits, dim=1)
  35.             correct = (pred == batch_src_label).sum().item()
  36.             acc = correct / BATCH_SIZE
  37.             train_acc += acc
  38.             validate_loss, validate_acc = validate()
  39.             val_loss += validate_loss
  40.             val_acc += validate_acc
  41.             print(
  42.                 "Epoch {:03d} Batch {:03d} train_loss: {:.4f} train_acc: {:.4f} val_loss: {:.4f} val_acc: {:.4f}".format
  43.                 (e, batch, loss.item(), acc, validate_loss, validate_acc))
  44.         train_losses.append(train_loss / NUM_BATCH_PER_EPOCH)
  45.         train_acces.append(train_acc / NUM_BATCH_PER_EPOCH)
  46.         val_losses.append(val_loss / NUM_BATCH_PER_EPOCH)
  47.         val_acces.append(val_acc / NUM_BATCH_PER_EPOCH)
  48.         # 测试
  49.         test()
  50.     res_plot(EPOCHS, train_losses, train_acces, val_losses, val_acces)
复制代码
说明:


  • train() 函数负责训练模子,记载训练和验证的损失和准确率。
  • multihop_sampling 函数用于对节点举行多跳采样。
  • model 函数负责前向流传,计算节点表示。
  • criterion 函数计算损失。
  • optimizer 函数更新模子参数。
  • validate() 函数用于验证模子在验证集上的性能。
  • test() 函数用于测试模子在测试集上的性能。
  • res_plot 函数用于绘制训练和验证过程中的损失和准确率曲线。
6. 定义验证与测试函数

在验证和测试阶段,我们关闭梯度计算,并评估模子在验证集和测试集上的性能:
  1. # 定义测试函数
  2. def validate():
  3.     model.eval()  # 测试模式
  4.     with torch.no_grad():  # 关闭梯度
  5.         val_sampling_result = multihop_sampling(val_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
  6.         val_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in val_sampling_result]
  7.         val_logits = model(val_x)
  8.         val_label = torch.from_numpy(data.y[val_index]).long().to(DEVICE)
  9.         loss = criterion(val_logits, val_label)
  10.         predict_y = val_logits.max(1)[1]
  11.         accuarcy = torch.eq(predict_y, val_label).float().mean().item()
  12.         return loss.item(), accuarcy
  13. # 定义测试函数
  14. def test():
  15.     model.eval()  # 测试模式
  16.     with torch.no_grad():  # 关闭梯度
  17.         test_sampling_result = multihop_sampling(test_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
  18.         test_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in test_sampling_result]
  19.         test_logits = model(test_x)
  20.         test_label = torch.from_numpy(data.y[test_index]).long().to(DEVICE)
  21.         predict_y = test_logits.max(1)[1]
  22.         accuarcy = torch.eq(predict_y, test_label).float().mean().item()
  23.         print("Test Accuracy: ", accuarcy)
复制代码
说明:


  • res_plot 函数用于绘制训练和验证过程中的损失和准确率曲线,并生存图像。
7. 可视化训练与验证过程

为了直观地观察模子在训练和验证过程中的体现,我们通过绘制损失和准确率曲线来分析模子的收敛性和性能。这段代码实现了训练损失、训练准确率、验证损失和验证准确率的可视化,并将结果生存为图像文件。
  1. def res_plot(epoch, train_losses, train_acces, val_losses, val_acces):
  2.     epoches = np.arange(0, epoch, 1)
  3.     plt.figure()
  4.     ax = plt.subplot(1, 2, 1)
  5.     # 画出训练结果
  6.     plt.plot(epoches, train_losses, 'b', label='train_loss')
  7.     plt.plot(epoches, train_acces, 'r', label='train_acc')
  8.     # plt.setp(ax.get_xticklabels())
  9.     plt.legend()
  10.     plt.subplot(1, 2, 2, sharey=ax)
  11.     # 画出训练结果
  12.     plt.plot(epoches, val_losses, 'k', label='val_loss')
  13.     plt.plot(epoches, val_acces, 'g', label='val_acc')
  14.     plt.legend()
  15.     plt.savefig('res_plot.jpg')
  16.     plt.show()
复制代码
main函数:
  1. # main函数,程序入口
  2. if __name__ == '__main__':
  3.     train()
复制代码
(二)结果分析

(1)运行结果

(2)准确与损失率曲线 

从曲线上可以看出,团体准确率比力高且趋于稳定,但经充实训练之后,val_loss值仍旧均位于1以上,可能与该模子的学习率过高、数据集处理不当、邻居采样不足等标题,以是此实例demo有待改进。 
四、GraphSAGE的优势与未来预测

通过上述实验,我们可以看到GraphSAGE在Cora数据集上的体现非常出色。相比于传统的GCN,GraphSAGE的采样机制使其能够更好地扩展到大规模图数据,同时保持较高的分类精度。
(1)优势


  • 高效性 :通过采样邻居节点,制止了对整个图的计算,显著降低了时间和空间复杂度。
  • 机动性 :支持多种聚合方式,可以根据详细任务选择符合的战略。
  • 可扩展性 :实用于动态图和超大规模图。
(2)未来预测
尽管GraphSAGE已经取得了显著的结果,但仍有许多值得探索的方向:


  • 更高效的采样战略 :如何计划更智能的采样方法,进一步提升模子性能?
  • 跨领域应用 :如何将GNN应用于更多领域,例如健康估计、寿命预测、生物信息学、金融分析等?
  • 理论分析 :深入研究GNN的表达能力和泛化能力。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

没腿的鸟

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表