IT评测·应用市场-qidao123.com技术社区

标题: GNN入门与实践——基于GraphSAGE在Cora数据集上的节点分类研究 [打印本页]

作者: 没腿的鸟    时间: 2025-3-5 02:03
标题: GNN入门与实践——基于GraphSAGE在Cora数据集上的节点分类研究
Hi,大家好,我是半亩花海。本文介绍了图神经网络(GNN)中的一种紧张算法——GraphSAGE,其通过采样邻居节点聚合信息,能够高效地处理大规模图数据,并通过一个完整的代码示例(包罗数据预处理、模子定义、训练过程、验证与测试以及结果可视化)展示了如安在 Cora 数据集上实现节点分类任务。
  目次
一、为什么我们必要图神经网络?
二、什么是 GraphSAGE?
(一)概念
(二)焦点思想
(三)数学公式
三、基于Cora数据集的GraphSAGE实现
(一)研究过程
(二)结果分析
四、GraphSAGE的优势与未来预测

一、为什么我们必要图神经网络?

近年来,随着深度学习的快速发展,神经网络在图像、文本和语音等领域取得了显著的乐成。然而,这些传统方法重要实用于欧几里得数据(如图像和序列),而许多实际世界中的数据本质上是图布局的,例如社交网络、分子布局、知识图谱等。传统的神经网络难以直接处理这种非欧几里得数据。
图神经网络(Graph Neural Network, GNN) 的出现为解决这一标题提供了新的思路。它通过建模节点之间的关系,能够有效地捕捉图布局中的复杂模式。GNN 已经在推荐系统、药物发现、交通预测等领域显现出巨大的潜力。
本文将通过一个详细的 GraphSAGE 示例,深入探究 GNN 的基本原理、实现细节以及其在实际任务中的应用。

二、什么是 GraphSAGE?

(一)概念

GraphSAGE(Graph Sample and Aggregation)是一种基于采样的图神经网络算法。与传统的图卷积网络(GCN)差别,GraphSAGE 不依靠于整个图的毗邻矩阵举行计算,而是通过对邻居节点举行采样和聚合来天生节点表示。这种方法使得 GraphSAGE 更加高效且可扩展,尤其实用于大规模图数据
(二)焦点思想


(三)数学公式

假设我们有一个图
,此中
是节点集合,
 是边集合。对于第
层,目标节点
的表示
​ 可以通过以下公式计算:

此中:

三、基于Cora数据集的GraphSAGE实现

下面我们将通过一个完整的代码示例,展示如何使用GraphSAGE在Cora数据集上举行节点分类任务。
   数据集及源代码链接:PyG-GraphSAGE(直接Download下来就行,好像有一处没加右括号,改正后直接运行main.py即可复现)。
  (一)研究过程

1. 数据预处理

首先,我们加载 Cora 数据集并对其举行归一化处理:
  1. import torch
  2. import numpy as np
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import matplotlib.pyplot as plt
  6. from net import GraphSage
  7. from data import CoraData
  8. from data import CiteseerData
  9. from data import PubmedData
  10. from sampling import multihop_sampling
  11. from collections import namedtuple
  12. # 数据集选择
  13. dataset = "cora"
  14. assert dataset in ["cora", "citeseer", "pubmed"]
  15. # 层数选择
  16. num_layers = 2
  17. assert num_layers in [2, 3]
  18. # 设置输入维度、隐藏层维度和邻居采样数量
  19. if dataset == "cora":
  20.     INPUT_DIM = 1433  # 输入维度
  21.     if num_layers == 2:
  22.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  23.         HIDDEN_DIM = [256, 7]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  24.         NUM_NEIGHBORS_LIST = [10, 10]  # 每阶采样邻居的节点数
  25.     else:
  26.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  27.         HIDDEN_DIM = [256, 128, 7]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  28.         NUM_NEIGHBORS_LIST = [10, 5, 5]  # 每阶采样邻居的节点数
  29. elif dataset == "citeseer":
  30.     INPUT_DIM = 3703  # 输入维度
  31.     if num_layers == 2:
  32.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  33.         HIDDEN_DIM = [256, 6]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  34.         NUM_NEIGHBORS_LIST = [10, 10]  # 每阶采样邻居的节点数
  35.     else:
  36.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  37.         HIDDEN_DIM = [256, 128, 6]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  38.         NUM_NEIGHBORS_LIST = [10, 5, 5]  # 每阶采样邻居的节点数
  39. else:
  40.     INPUT_DIM = 500  # 输入维度
  41.     if num_layers == 2:
  42.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  43.         HIDDEN_DIM = [256, 3]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  44.         NUM_NEIGHBORS_LIST = [10, 10]  # 每阶采样邻居的节点数
  45.     else:
  46.         # Note: 采样的邻居阶数需要与GCN的层数保持一致
  47.         HIDDEN_DIM = [256, 128, 3]  # 隐藏单元节点数(2层模型,最后一个是输出的类别)
  48.         NUM_NEIGHBORS_LIST = [10, 5, 5]  # 每阶采样邻居的节点数
  49. # 定义超参数
  50. BATCH_SIZE = 16  # 批处理大小
  51. EPOCHS = 10  # 训练轮数
  52. NUM_BATCH_PER_EPOCH = 20  # 每个epoch循环的批次数
  53. if dataset == "citeseer":
  54.     LEARNING_RATE = 0.1  # 学习率
  55. else:
  56.     LEARNING_RATE = 0.01
  57. DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  58. # 数据结构定义
  59. Data = namedtuple('Data', ['x', 'y', 'adjacency_dict', 'train_mask', 'val_mask', 'test_mask'])
  60. # 载入数据
  61. if dataset == "cora":
  62.     data = CoraData().data
  63. elif dataset == "citeseer":
  64.     data = CiteseerData().data
  65. else:
  66.     data = PubmedData().data
  67. # 数据归一化
  68. if dataset == "citeseer":
  69.     x = data.x
  70. else:
  71.     x = data.x / data.x.sum(1, keepdims=True)  # 归一化数据,使得每一行和为1
复制代码
说明:

2. 定义训练、验证、测试集

接下来,我们将数据集分别为训练集、验证集和测试集:
  1. # 定义训练、验证、测试集
  2. train_index = np.where(data.train_mask)[0]
  3. train_label = data.y
  4. val_index = np.where(data.val_mask)[0]
  5. test_index = np.where(data.test_mask)[0]
复制代码
说明:

3. 实例化模子

我们实例化一个 GraphSAGE 模子,并指定输入维度、隐藏层维度和邻居采样数量:
  1. # 实例化模型
  2. model = GraphSage(
  3.     input_dim=INPUT_DIM,
  4.     hidden_dim=HIDDEN_DIM,
  5.     num_neighbors_list=NUM_NEIGHBORS_LIST,
  6.     aggr_neighbor_method="mean",
  7.     aggr_hidden_method="sum"
  8. ).to(DEVICE)
  9. print(model)
复制代码
说明:

4. 定义损失函数和优化器

我们使用交织熵损失函数和 Adam 优化器来训练模子:
  1. # 定义损失函数和优化器
  2. criterion = nn.CrossEntropyLoss().to(DEVICE)
  3. optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)
复制代码
说明:

5. 定义训练函数

训练过程分为以下几个步骤:
(1)采样邻居:对每个批次的节点举行多跳采样,获取其邻居节点的特征。
(2)前向流传:将采样得到的节点特征送入模子,计算节点表示。
(3)损失计算:使用交织熵损失函数计算损失,并通过反向流传更新模子参数。
  1. # 定义训练函数
  2. def train():
  3.     train_losses = []
  4.     train_acces = []
  5.     val_losses = []
  6.     val_acces = []
  7.     model.train()  # 训练模式
  8.     for e in range(EPOCHS):
  9.         train_loss = 0
  10.         train_acc = 0
  11.         val_loss = 0
  12.         val_acc = 0
  13.         if e % 5 == 0:
  14.             optimizer.param_groups[0]['lr'] *= 0.1  # 学习率衰减
  15.         for batch in range(NUM_BATCH_PER_EPOCH):  # 每个epoch循环的批次数
  16.             # 随机从训练集中抽取batch_size个节点(batch_size,num_train_node)
  17.             batch_src_index = np.random.choice(train_index, size=(BATCH_SIZE,))
  18.             # 根据训练节点提取其标签(batch_size,num_train_node)
  19.             batch_src_label = torch.from_numpy(train_label[batch_src_index]).long().to(DEVICE)
  20.             # 进行多跳采样(num_layers+1,num_node)
  21.             batch_sampling_result = multihop_sampling(batch_src_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
  22.             # 根据采样的节点id构造采样节点特征(num_layers+1,num_node,input_dim)
  23.             batch_sampling_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in batch_sampling_result]
  24.             # 送入模型开始训练
  25.             batch_train_logits = model(batch_sampling_x)
  26.             # 计算损失
  27.             loss = criterion(batch_train_logits, batch_src_label)
  28.             train_loss += loss.item()
  29.             # 更新参数
  30.             optimizer.zero_grad()
  31.             loss.backward()  # 反向传播计算参数的梯度
  32.             optimizer.step()  # 使用优化方法进行梯度更新
  33.             # 计算训练精度
  34.             _, pred = torch.max(batch_train_logits, dim=1)
  35.             correct = (pred == batch_src_label).sum().item()
  36.             acc = correct / BATCH_SIZE
  37.             train_acc += acc
  38.             validate_loss, validate_acc = validate()
  39.             val_loss += validate_loss
  40.             val_acc += validate_acc
  41.             print(
  42.                 "Epoch {:03d} Batch {:03d} train_loss: {:.4f} train_acc: {:.4f} val_loss: {:.4f} val_acc: {:.4f}".format
  43.                 (e, batch, loss.item(), acc, validate_loss, validate_acc))
  44.         train_losses.append(train_loss / NUM_BATCH_PER_EPOCH)
  45.         train_acces.append(train_acc / NUM_BATCH_PER_EPOCH)
  46.         val_losses.append(val_loss / NUM_BATCH_PER_EPOCH)
  47.         val_acces.append(val_acc / NUM_BATCH_PER_EPOCH)
  48.         # 测试
  49.         test()
  50.     res_plot(EPOCHS, train_losses, train_acces, val_losses, val_acces)
复制代码
说明:

6. 定义验证与测试函数

在验证和测试阶段,我们关闭梯度计算,并评估模子在验证集和测试集上的性能:
  1. # 定义测试函数
  2. def validate():
  3.     model.eval()  # 测试模式
  4.     with torch.no_grad():  # 关闭梯度
  5.         val_sampling_result = multihop_sampling(val_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
  6.         val_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in val_sampling_result]
  7.         val_logits = model(val_x)
  8.         val_label = torch.from_numpy(data.y[val_index]).long().to(DEVICE)
  9.         loss = criterion(val_logits, val_label)
  10.         predict_y = val_logits.max(1)[1]
  11.         accuarcy = torch.eq(predict_y, val_label).float().mean().item()
  12.         return loss.item(), accuarcy
  13. # 定义测试函数
  14. def test():
  15.     model.eval()  # 测试模式
  16.     with torch.no_grad():  # 关闭梯度
  17.         test_sampling_result = multihop_sampling(test_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
  18.         test_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in test_sampling_result]
  19.         test_logits = model(test_x)
  20.         test_label = torch.from_numpy(data.y[test_index]).long().to(DEVICE)
  21.         predict_y = test_logits.max(1)[1]
  22.         accuarcy = torch.eq(predict_y, test_label).float().mean().item()
  23.         print("Test Accuracy: ", accuarcy)
复制代码
说明:

7. 可视化训练与验证过程

为了直观地观察模子在训练和验证过程中的体现,我们通过绘制损失和准确率曲线来分析模子的收敛性和性能。这段代码实现了训练损失、训练准确率、验证损失和验证准确率的可视化,并将结果生存为图像文件。
  1. def res_plot(epoch, train_losses, train_acces, val_losses, val_acces):
  2.     epoches = np.arange(0, epoch, 1)
  3.     plt.figure()
  4.     ax = plt.subplot(1, 2, 1)
  5.     # 画出训练结果
  6.     plt.plot(epoches, train_losses, 'b', label='train_loss')
  7.     plt.plot(epoches, train_acces, 'r', label='train_acc')
  8.     # plt.setp(ax.get_xticklabels())
  9.     plt.legend()
  10.     plt.subplot(1, 2, 2, sharey=ax)
  11.     # 画出训练结果
  12.     plt.plot(epoches, val_losses, 'k', label='val_loss')
  13.     plt.plot(epoches, val_acces, 'g', label='val_acc')
  14.     plt.legend()
  15.     plt.savefig('res_plot.jpg')
  16.     plt.show()
复制代码
main函数:
  1. # main函数,程序入口
  2. if __name__ == '__main__':
  3.     train()
复制代码
(二)结果分析

(1)运行结果

(2)准确与损失率曲线 

从曲线上可以看出,团体准确率比力高且趋于稳定,但经充实训练之后,val_loss值仍旧均位于1以上,可能与该模子的学习率过高、数据集处理不当、邻居采样不足等标题,以是此实例demo有待改进。 
四、GraphSAGE的优势与未来预测

通过上述实验,我们可以看到GraphSAGE在Cora数据集上的体现非常出色。相比于传统的GCN,GraphSAGE的采样机制使其能够更好地扩展到大规模图数据,同时保持较高的分类精度。
(1)优势

(2)未来预测
尽管GraphSAGE已经取得了显著的结果,但仍有许多值得探索的方向:


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 IT评测·应用市场-qidao123.com技术社区 (https://dis.qidao123.com/) Powered by Discuz! X3.4