# Imports
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

# Load the PubMed dataset
dataset = Planetoid(root='', name='Pubmed')
data = dataset[0]
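# Optional sanity check (added, not part of the original listing): for PubMed
# this should report 19717 nodes, 500 features and 3 classes.
print(f'Dataset: {dataset}')
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')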
# Neighbor sampling
# Use NeighborLoader for this task:
# sample 5 neighbors of each target node and 10 neighbors of each of those neighbors,
# and group the 60 training target nodes into mini-batches of 16 (ceil(60/16) = 4 batches).
train_loader = NeighborLoader(
    data,                         # data source
    num_neighbors=[5, 10],        # neighbors sampled per hop: 5 at the first hop, 10 at the second
    batch_size=16,
    input_nodes=data.train_mask,  # the 60 training target nodes
)
# Iterate over the loader to inspect the sampled subgraphs
# for i, subgraph in enumerate(train_loader):
#     print(f'Subgraph {i}: {subgraph}')
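# Quick check on one mini-batch (added here; assumes PyG's NeighborLoader
# convention that the seed/target nodes come first and their count is exposed
# as `batch_size`, so it should be at most 16 for this loader):
first_batch = next(iter(train_loader))
print(f'Seed nodes: {first_batch.batch_size}, sampled nodes: {first_batch.num_nodes}')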
# Visualize the sampled subgraphs
# fig = plt.figure(figsize=(16, 16))
# for idx, (subdata, pos) in enumerate(zip(train_loader, [221, 222, 223, 224])):
#     G = to_networkx(subdata, to_undirected=True)
#     ax = fig.add_subplot(pos)
#     ax.set_title(f'Subgraph {idx}', fontsize=24)
#     plt.axis('off')
#     nx.draw_networkx(G, pos=nx.spring_layout(G), with_labels=False, node_color=subdata.y)
# plt.show()
# Accuracy metric: fraction of predictions that match the labels
def accuracy(pre_y, y):
    return ((pre_y == y).sum() / len(y)).item()
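# Illustrative self-check (added): 3 of the 4 predictions match, so accuracy is 0.75.
assert accuracy(torch.tensor([0, 1, 2, 2]), torch.tensor([0, 1, 1, 2])) == 0.75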
# Define a two-layer GraphSAGE model
class GraphSAGE(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, dim_out)

    def forward(self, x, edge_index):
        h = self.sage1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.sage2(h, edge_index)
        return h
    # Mini-batch training: fit() loops over epochs and, within each epoch,
    # loops over the mini-batches produced by the NeighborLoader.
    def fit(self, loader, epochs):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        self.train()
        for epoch in range(epochs + 1):
            total_loss = 0
            acc = 0
            val_loss = 0
            val_acc = 0
            for batch in loader:
                optimizer.zero_grad()
                out = self(batch.x, batch.edge_index)
                loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
                total_loss += loss.item()
                acc += accuracy(out[batch.train_mask].argmax(dim=1), batch.y[batch.train_mask])
                loss.backward()
                optimizer.step()
                # Validation on the nodes of the current mini-batch
                val_loss += criterion(out[batch.val_mask], batch.y[batch.val_mask]).item()
                val_acc += accuracy(out[batch.val_mask].argmax(dim=1), batch.y[batch.val_mask])
            if epoch % 20 == 0:
                print(f'Epoch {epoch:>3} '
                      f'| Train Loss: {total_loss/len(loader):.3f} '
                      f'| Train Acc: {acc/len(loader)*100:>6.2f}% '
                      f'| Val Loss: {val_loss/len(loader):.2f} '
                      f'| Val Acc: {val_acc/len(loader)*100:.2f}%')
    @torch.no_grad()
    def test(self, data):
        self.eval()
        out = self(data.x, data.edge_index)
        acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
        return acc
# Create GraphSAGE
graphsage = GraphSAGE(dataset.num_features, 64, dataset.num_classes)
print(graphsage)

# Train with mini-batches
graphsage.fit(train_loader, 200)
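# Evaluate the trained model on the test split using the test() method defined
# above (call added as a usage example; the original listing stops after fit()).
test_acc = graphsage.test(data)
print(f'GraphSAGE test accuracy: {test_acc*100:.2f}%')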