怎么使用Pytorch+PyG实现GraphSAGE

其他教程   发布日期:2025年03月29日   浏览次数:123

这篇文章主要讲解了“怎么使用Pytorch+PyG实现GraphSAGE”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“怎么使用Pytorch+PyG实现GraphSAGE”吧!

GraphSAGE简介

GraphSAGE(Graph Sampling and Aggregation)是一种常见的图神经网络模型,主要用于结点级别的表征学习。该模型基于采样和聚合策略,将一个结点及其邻居节点信息融合在一起,得到其表征表示,并通过多轮迭代更新来提高表征的精度。

实现步骤

数据准备

在本次实现中,我们仍然使用Cora数据集作为示例进行测试,由于GraphSage主要聚焦于单一节点特征的更新,因此这里不需要对数据集做特别处理,只需要将数据转化成PyG格式即可。

  1. import torch.nn.functional as F
  2. from torch_geometric.datasets import Planetoid
  3. from torch_geometric.utils import from_networkx, to_networkx
  4. # 加载cora数据集
  5. dataset = Planetoid(root='./cora', name='Cora')
  6. data = dataset[0]
  7. # 将nx.Graph形式的图转换成PyG需要的格式
  8. graph = to_networkx(data)
  9. data = from_networkx(graph)
  10. # 获取节点数量和特征向量维度
  11. num_nodes = data.num_nodes
  12. num_features = dataset.num_features
  13. num_classes = dataset.num_classes
  14. # 建立需要训练的节点分割数据集
  15. data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
  16. data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
  17. data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
  18. data.train_mask[:num_nodes - 1000] = True
  19. data.test_mask[-1000:] = True
  20. data.val_mask[num_nodes - 2000: num_nodes - 1000] = True

实现模型

接下来,我们需要定义GraphSAGE模型。与传统的GCN中只需要一层卷积操作不同,GraphSAGE包含两层卷积和采样(也称“聚合”)操作。

  1. from torch.nn import Sequential as Seq, Linear as Lin, ReLU
  2. from torch_geometric.nn import SAGEConv
  3. class GraphSAGE(torch.nn.Module):
  4. def __init__(self, hidden_channels, num_layers):
  5. super(GraphSAGE, self).__init__()
  6. self.convs = nn.ModuleList()
  7. for i in range(num_layers):
  8. in_channels = hidden_channels if i != 0 else num_features
  9. out_channels = num_classes if i == num_layers - 1 else hidden_channels
  10. self.convs.append(SAGEConv(in_channels, out_channels))
  11. def forward(self, x, edge_index):
  12. for _, conv in enumerate(self.convs[:-1]):
  13. x = F.relu(conv(x, edge_index))
  14. # 最后一层不用激活函数
  15. x = self.convs[-1](x, edge_index)
  16. return F.log_softmax(x, dim=-1)

在上述代码中,我们实现了多层GraphSAGE卷积和相应的聚合函数,并使用ReLU和softmax函数来进行特征提取和分类分数的输出。

模型训练

定义好模型之后,就可以开始针对Cora数据集进行模型训练。首先还是需要先指定优化器和损失函数,并设定一些参数用于记录训练过程中的信息,如Epochs、Batch size、学习率等。

  1. # 初始化GraphSage并指定参数
  2. num_layers = 2
  3. hidden_channels = 256
  4. model = GraphSAGE(hidden_channels, num_layers).to(device)
  5. optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
  6. loss_func = nn.CrossEntropyLoss()
  7. # 训练过程
  8. for epoch in range(500):
  9. model.train()
  10. optimizer.zero_grad()
  11. out = model(data.x.to(device), data.edge_index.to(device))
  12. loss = loss_func(out[data.train_mask], data.y.to(device)[data.train_mask])
  13. loss.backward()
  14. optimizer.step()
  15. # 在各个测试阶段检测一下准确率
  16. if epoch % 10 == 0:
  17. with torch.no_grad():
  18. _, pred = model(data.x.to(device), data.edge_index.to(device)).max(dim=1)
  19. correct = float(pred[data.test_mask].eq(data.y.to(device)[data.test_mask]).sum().item())
  20. acc = correct / data.test_mask.sum().item()
  21. print("Epoch {:03d}, Train Loss {:.4f}, Test Acc {:.4f}".format(
  22. epoch, loss.item(), acc))

在上述代码中,我们使用有标记的训练数据拟合GraphSAGE模型,在各个验证阶段测试准确率,并通过梯度下降法优化损失函数。

以上就是怎么使用Pytorch+PyG实现GraphSAGE的详细内容,更多关于怎么使用Pytorch+PyG实现GraphSAGE的资料请关注九品源码其它相关文章!