Hi,大家好,我是半亩花海。本文介绍了图神经网络(GNN)中的一种紧张算法——GraphSAGE,其通过采样邻居节点和聚合信息,能够高效地处理大规模图数据,并通过一个完整的代码示例(包罗数据预处理、模子定义、训练过程、验证与测试以及结果可视化)展示了如安在 Cora 数据集上实现节点分类任务。
目次
一、为什么我们必要图神经网络?
二、什么是 GraphSAGE?
(一)概念
(二)焦点思想
(三)数学公式
三、基于Cora数据集的GraphSAGE实现
(一)研究过程
(二)结果分析
四、GraphSAGE的优势与未来预测
一、为什么我们必要图神经网络?
近年来,随着深度学习的快速发展,神经网络在图像、文本和语音等领域取得了显著的乐成。然而,这些传统方法重要实用于欧几里得数据(如图像和序列),而许多实际世界中的数据本质上是图布局的,例如社交网络、分子布局、知识图谱等。传统的神经网络难以直接处理这种非欧几里得数据。
图神经网络(Graph Neural Network, GNN) 的出现为解决这一标题提供了新的思路。它通过建模节点之间的关系,能够有效地捕捉图布局中的复杂模式。GNN 已经在推荐系统、药物发现、交通预测等领域显现出巨大的潜力。
本文将通过一个详细的 GraphSAGE 示例,深入探究 GNN 的基本原理、实现细节以及其在实际任务中的应用。
二、什么是 GraphSAGE?
(一)概念
GraphSAGE(Graph Sample and Aggregation)是一种基于采样的图神经网络算法。与传统的图卷积网络(GCN)差别,GraphSAGE 不依靠于整个图的毗邻矩阵举行计算,而是通过对邻居节点举行采样和聚合来天生节点表示。这种方法使得 GraphSAGE 更加高效且可扩展,尤其实用于大规模图数据。
(二)焦点思想
- 采样(Sampling) :为了减少计算开销,GraphSAGE 对每个节点的邻居举行随机采样,而不是使用所有邻居。
- 聚合(Aggregation) :通过聚合采样邻居的信息,更新目标节点的特征表示。常见的聚合方式包罗均值聚合(mean)、最大池化(max-pooling)等。
- 逐层流传(Layer-wise Propagation) :每一层都会根据前一层的节点表示和邻居信息天生新的节点表示。
(三)数学公式
假设我们有一个图 ,此中 是节点集合, 是边集合。对于第 层,目标节点 的表示 可以通过以下公式计算:
此中:
- 表示节点 的邻居集合;
- 是聚合函数,例如均值聚合;
- 是可学习的权重矩阵;
- 是激活函数,例如 。
三、基于Cora数据集的GraphSAGE实现
下面我们将通过一个完整的代码示例,展示如何使用GraphSAGE在Cora数据集上举行节点分类任务。
数据集及源代码链接:PyG-GraphSAGE(直接Download下来就行,好像有一处没加右括号,改正后直接运行main.py即可复现)。
(一)研究过程
1. 数据预处理
首先,我们加载 Cora 数据集并对其举行归一化处理:
- import torch
- import numpy as np
- import torch.nn as nn
- import torch.optim as optim
- import matplotlib.pyplot as plt
- from net import GraphSage
- from data import CoraData
- from data import CiteseerData
- from data import PubmedData
- from sampling import multihop_sampling
- from collections import namedtuple
- # 数据集选择
- dataset = "cora"
- assert dataset in ["cora", "citeseer", "pubmed"]
- # 层数选择
- num_layers = 2
- assert num_layers in [2, 3]
- # 设置输入维度、隐藏层维度和邻居采样数量
- if dataset == "cora":
- INPUT_DIM = 1433 # 输入维度
- if num_layers == 2:
- # Note: 采样的邻居阶数需要与GCN的层数保持一致
- HIDDEN_DIM = [256, 7] # 隐藏单元节点数(2层模型,最后一个是输出的类别)
- NUM_NEIGHBORS_LIST = [10, 10] # 每阶采样邻居的节点数
- else:
- # Note: 采样的邻居阶数需要与GCN的层数保持一致
- HIDDEN_DIM = [256, 128, 7] # 隐藏单元节点数(2层模型,最后一个是输出的类别)
- NUM_NEIGHBORS_LIST = [10, 5, 5] # 每阶采样邻居的节点数
- elif dataset == "citeseer":
- INPUT_DIM = 3703 # 输入维度
- if num_layers == 2:
- # Note: 采样的邻居阶数需要与GCN的层数保持一致
- HIDDEN_DIM = [256, 6] # 隐藏单元节点数(2层模型,最后一个是输出的类别)
- NUM_NEIGHBORS_LIST = [10, 10] # 每阶采样邻居的节点数
- else:
- # Note: 采样的邻居阶数需要与GCN的层数保持一致
- HIDDEN_DIM = [256, 128, 6] # 隐藏单元节点数(2层模型,最后一个是输出的类别)
- NUM_NEIGHBORS_LIST = [10, 5, 5] # 每阶采样邻居的节点数
- else:
- INPUT_DIM = 500 # 输入维度
- if num_layers == 2:
- # Note: 采样的邻居阶数需要与GCN的层数保持一致
- HIDDEN_DIM = [256, 3] # 隐藏单元节点数(2层模型,最后一个是输出的类别)
- NUM_NEIGHBORS_LIST = [10, 10] # 每阶采样邻居的节点数
- else:
- # Note: 采样的邻居阶数需要与GCN的层数保持一致
- HIDDEN_DIM = [256, 128, 3] # 隐藏单元节点数(2层模型,最后一个是输出的类别)
- NUM_NEIGHBORS_LIST = [10, 5, 5] # 每阶采样邻居的节点数
- # 定义超参数
- BATCH_SIZE = 16 # 批处理大小
- EPOCHS = 10 # 训练轮数
- NUM_BATCH_PER_EPOCH = 20 # 每个epoch循环的批次数
- if dataset == "citeseer":
- LEARNING_RATE = 0.1 # 学习率
- else:
- LEARNING_RATE = 0.01
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- # 数据结构定义
- Data = namedtuple('Data', ['x', 'y', 'adjacency_dict', 'train_mask', 'val_mask', 'test_mask'])
- # 载入数据
- if dataset == "cora":
- data = CoraData().data
- elif dataset == "citeseer":
- data = CiteseerData().data
- else:
- data = PubmedData().data
- # 数据归一化
- if dataset == "citeseer":
- x = data.x
- else:
- x = data.x / data.x.sum(1, keepdims=True) # 归一化数据,使得每一行和为1
复制代码 说明:
- INPUT_DIM 是节点特征的维度;
- HIDDEN_DIM 是隐藏层的维度列表;
- NUM_NEIGHBORS_LIST 是每层采样的邻居数量;
- BATCH_SIZE 是每次训练时使用的样本数量;
- EPOCHS 是总的训练轮数;
- NUM_BATCH_PER_EPOCH 是每个 epoch 中的批次数量;
- LEARNING_RATE 是学习率;
- DEVICE 是使用的装备(CPU 或 GPU)。
2. 定义训练、验证、测试集
接下来,我们将数据集分别为训练集、验证集和测试集:
- # 定义训练、验证、测试集
- train_index = np.where(data.train_mask)[0]
- train_label = data.y
- val_index = np.where(data.val_mask)[0]
- test_index = np.where(data.test_mask)[0]
复制代码 说明:
- train_index 是训练集的索引;
- train_label 是训练集的标签;
- val_index 是验证集的索引;
- test_index 是测试集的索引。
3. 实例化模子
我们实例化一个 GraphSAGE 模子,并指定输入维度、隐藏层维度和邻居采样数量:
- # 实例化模型
- model = GraphSage(
- input_dim=INPUT_DIM,
- hidden_dim=HIDDEN_DIM,
- num_neighbors_list=NUM_NEIGHBORS_LIST,
- aggr_neighbor_method="mean",
- aggr_hidden_method="sum"
- ).to(DEVICE)
- print(model)
复制代码 说明:
- input_dim 是节点特征的维度;
- hidden_dim 是隐藏层的维度列表;
- num_neighbors_list 是每层采样的邻居数量;
- aggr_neighbor_method 是邻居聚合的方式(例如均值聚合);
- aggr_hidden_method 是隐藏层聚合的方式(例如求和)。
4. 定义损失函数和优化器
我们使用交织熵损失函数和 Adam 优化器来训练模子:
- # 定义损失函数和优化器
- criterion = nn.CrossEntropyLoss().to(DEVICE)
- optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)
复制代码 说明:
- criterion 是交织熵损失函数;
- optimizer 是 Adam 优化器,带有权重衰减(L2 正则化)。
5. 定义训练函数
训练过程分为以下几个步骤:
(1)采样邻居:对每个批次的节点举行多跳采样,获取其邻居节点的特征。
(2)前向流传:将采样得到的节点特征送入模子,计算节点表示。
(3)损失计算:使用交织熵损失函数计算损失,并通过反向流传更新模子参数。
- # 定义训练函数
- def train():
- train_losses = []
- train_acces = []
- val_losses = []
- val_acces = []
- model.train() # 训练模式
- for e in range(EPOCHS):
- train_loss = 0
- train_acc = 0
- val_loss = 0
- val_acc = 0
- if e % 5 == 0:
- optimizer.param_groups[0]['lr'] *= 0.1 # 学习率衰减
- for batch in range(NUM_BATCH_PER_EPOCH): # 每个epoch循环的批次数
- # 随机从训练集中抽取batch_size个节点(batch_size,num_train_node)
- batch_src_index = np.random.choice(train_index, size=(BATCH_SIZE,))
- # 根据训练节点提取其标签(batch_size,num_train_node)
- batch_src_label = torch.from_numpy(train_label[batch_src_index]).long().to(DEVICE)
- # 进行多跳采样(num_layers+1,num_node)
- batch_sampling_result = multihop_sampling(batch_src_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
- # 根据采样的节点id构造采样节点特征(num_layers+1,num_node,input_dim)
- batch_sampling_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in batch_sampling_result]
- # 送入模型开始训练
- batch_train_logits = model(batch_sampling_x)
- # 计算损失
- loss = criterion(batch_train_logits, batch_src_label)
- train_loss += loss.item()
- # 更新参数
- optimizer.zero_grad()
- loss.backward() # 反向传播计算参数的梯度
- optimizer.step() # 使用优化方法进行梯度更新
- # 计算训练精度
- _, pred = torch.max(batch_train_logits, dim=1)
- correct = (pred == batch_src_label).sum().item()
- acc = correct / BATCH_SIZE
- train_acc += acc
- validate_loss, validate_acc = validate()
- val_loss += validate_loss
- val_acc += validate_acc
- print(
- "Epoch {:03d} Batch {:03d} train_loss: {:.4f} train_acc: {:.4f} val_loss: {:.4f} val_acc: {:.4f}".format
- (e, batch, loss.item(), acc, validate_loss, validate_acc))
- train_losses.append(train_loss / NUM_BATCH_PER_EPOCH)
- train_acces.append(train_acc / NUM_BATCH_PER_EPOCH)
- val_losses.append(val_loss / NUM_BATCH_PER_EPOCH)
- val_acces.append(val_acc / NUM_BATCH_PER_EPOCH)
- # 测试
- test()
- res_plot(EPOCHS, train_losses, train_acces, val_losses, val_acces)
复制代码 说明:
- train() 函数负责训练模子,记载训练和验证的损失和准确率。
- multihop_sampling 函数用于对节点举行多跳采样。
- model 函数负责前向流传,计算节点表示。
- criterion 函数计算损失。
- optimizer 函数更新模子参数。
- validate() 函数用于验证模子在验证集上的性能。
- test() 函数用于测试模子在测试集上的性能。
- res_plot 函数用于绘制训练和验证过程中的损失和准确率曲线。
6. 定义验证与测试函数
在验证和测试阶段,我们关闭梯度计算,并评估模子在验证集和测试集上的性能:
- # 定义测试函数
- def validate():
- model.eval() # 测试模式
- with torch.no_grad(): # 关闭梯度
- val_sampling_result = multihop_sampling(val_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
- val_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in val_sampling_result]
- val_logits = model(val_x)
- val_label = torch.from_numpy(data.y[val_index]).long().to(DEVICE)
- loss = criterion(val_logits, val_label)
- predict_y = val_logits.max(1)[1]
- accuarcy = torch.eq(predict_y, val_label).float().mean().item()
- return loss.item(), accuarcy
- # 定义测试函数
- def test():
- model.eval() # 测试模式
- with torch.no_grad(): # 关闭梯度
- test_sampling_result = multihop_sampling(test_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
- test_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in test_sampling_result]
- test_logits = model(test_x)
- test_label = torch.from_numpy(data.y[test_index]).long().to(DEVICE)
- predict_y = test_logits.max(1)[1]
- accuarcy = torch.eq(predict_y, test_label).float().mean().item()
- print("Test Accuracy: ", accuarcy)
复制代码 说明:
- res_plot 函数用于绘制训练和验证过程中的损失和准确率曲线,并生存图像。
7. 可视化训练与验证过程
为了直观地观察模子在训练和验证过程中的体现,我们通过绘制损失和准确率曲线来分析模子的收敛性和性能。这段代码实现了训练损失、训练准确率、验证损失和验证准确率的可视化,并将结果生存为图像文件。
- def res_plot(epoch, train_losses, train_acces, val_losses, val_acces):
- epoches = np.arange(0, epoch, 1)
- plt.figure()
- ax = plt.subplot(1, 2, 1)
- # 画出训练结果
- plt.plot(epoches, train_losses, 'b', label='train_loss')
- plt.plot(epoches, train_acces, 'r', label='train_acc')
- # plt.setp(ax.get_xticklabels())
- plt.legend()
- plt.subplot(1, 2, 2, sharey=ax)
- # 画出训练结果
- plt.plot(epoches, val_losses, 'k', label='val_loss')
- plt.plot(epoches, val_acces, 'g', label='val_acc')
- plt.legend()
- plt.savefig('res_plot.jpg')
- plt.show()
复制代码 main函数:
- # main函数,程序入口
- if __name__ == '__main__':
- train()
复制代码 (二)结果分析
(1)运行结果
(2)准确与损失率曲线
从曲线上可以看出,团体准确率比力高且趋于稳定,但经充实训练之后,val_loss值仍旧均位于1以上,可能与该模子的学习率过高、数据集处理不当、邻居采样不足等标题,以是此实例demo有待改进。
四、GraphSAGE的优势与未来预测
通过上述实验,我们可以看到GraphSAGE在Cora数据集上的体现非常出色。相比于传统的GCN,GraphSAGE的采样机制使其能够更好地扩展到大规模图数据,同时保持较高的分类精度。
(1)优势
- 高效性 :通过采样邻居节点,制止了对整个图的计算,显著降低了时间和空间复杂度。
- 机动性 :支持多种聚合方式,可以根据详细任务选择符合的战略。
- 可扩展性 :实用于动态图和超大规模图。
(2)未来预测
尽管GraphSAGE已经取得了显著的结果,但仍有许多值得探索的方向:
- 更高效的采样战略 :如何计划更智能的采样方法,进一步提升模子性能?
- 跨领域应用 :如何将GNN应用于更多领域,例如健康估计、寿命预测、生物信息学、金融分析等?
- 理论分析 :深入研究GNN的表达能力和泛化能力。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |