GraphSAGE项目练手

打印 上一主题 下一主题

主题 672|帖子 672|积分 2016

  1. # 导包
  2. from torch_geometric.datasets import Planetoid
  3. from torch_geometric.loader import NeighborLoader
  4. from torch_geometric.utils import to_networkx
  5. import numpy as np
  6. import networkx as nx
  7. import matplotlib.pyplot as plt
  8. import torch
  9. import torch.nn.functional as F
  10. from torch_geometric.nn import SAGEConv
  11. # 导入PubMed数据集
  12. dataset = Planetoid(root='',name='Pubmed')
  13. data = dataset[0]
  14. # 邻居采样
  15. # 使用NeighborLoader 来完成这一任务。
  16. # 保留目的节点的10个邻居 和 其邻居的10个邻居, 对60个目的节点进行分组,每16个目的节点为一组
  17. # 进行采样
  18. train_loader = NeighborLoader(
  19.     data,# 数据源
  20.     num_neighbors=[5,10], # 每一层采样的邻居采样量,第一层5,第二层10
  21.     batch_size=16,
  22.     input_nodes=data.train_mask # 60个训练目的节点
  23. )
  24. # 遍历数据检验
  25. # for i,subgraph in enumerate(train_loader):
  26. #     print(f'Subgraph{i}:{subgraph}')
  27. # 子图可视化
  28. # fig = plt.figure(figsize=(16,16))
  29. # for idx,(subdata,pos) in enumerate(zip(train_loader,[221,222,223,224])):
  30. #     G = to_networkx(subdata,to_undirected=True)
  31. #     ax = fig.add_subplot(pos)
  32. #     ax.set_title(f'Subgraph{idx},fonts=24')
  33. #     plt.axis('off')
  34. #     nx.draw_networkx(G,pos=nx.spring_layout(G),with_labels=False,node_color=subdata.y)
  35. # plt.show()
  36. # 实现准确率评估模型
  37. def  accuracy(pre_y,y):
  38.     return ((pre_y==y).sum() / len(y)).item()
  39. # 定义GraphSAGE
  40. class GraphSAGE(torch.nn.Module):
  41.     def __init__(self,dim_in,dim_h,dim_out):
  42.         super().__init__()
  43.         self.sage1 = SAGEConv(dim_in,dim_h)
  44.         self.sage2= SAGEConv(dim_h,dim_out)
  45.     def forward(self,x,edge_index):
  46.         h = self.sage1(x,edge_index)
  47.         h = torch.relu(h)
  48.         h = F.dropout(h,p=0.5,training=self.training)
  49.         h = self.sage2(h,edge_index)
  50.         return h
  51. # 使用小批量训练,Fit函数要修改为先循环epoch次,然后循环批数据,以在每个批数据上训练epoch次
  52.     def fit(self,loader,epochs):
  53.         criterion = torch.nn.CrossEntropyLoss()
  54.         optimizer = torch.optim.Adam(self.parameters(),lr=0.01)
  55.         self.train()
  56.         for epoch in range(epochs+1):
  57.             total_loss = 0
  58.             acc = 0
  59.             val_loss = 0
  60.             val_acc = 0
  61.             for batch in loader:
  62.                 optimizer.zero_grad()
  63.                 out = self(batch.x, batch.edge_index)
  64.                 loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
  65.                 total_loss += loss.item()
  66.                 acc += accuracy(out[batch.train_mask].argmax(dim=1), batch.y[batch.train_mask])
  67.                 loss.backward()
  68.                 optimizer.step()
  69.                 # Validation
  70.                 val_loss += criterion(out[batch.val_mask], batch.y[batch.val_mask])
  71.                 val_acc += accuracy(out[batch.val_mask].argmax(dim=1), batch.y[batch.val_mask])
  72.                 if epoch % 20 == 0:
  73.                     print(f'Epoch {epoch:>3} | Train Loss: {loss/len(loader):.3f} | Train Acc: {acc/len(loader)*100:>6.2f}% | Val Loss: {val_loss/len(train_loader):.2f} | Val Acc: {val_acc/len(train_loader)*100:.2f}%')
  74. @torch.no_grad()
  75. def test(self, data):
  76.     self.eval()
  77.     out = self(data.x, data.edge_index)
  78.     acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
  79.     return acc
  80. # Create GraphSAGE
  81. graphsage = GraphSAGE(dataset.num_features, 64, dataset.num_classes)
  82. print(graphsage)
  83. # Train
  84. graphsage.fit(train_loader, 200)
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

滴水恩情

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表