Pytorch+PyG实现GraphSAGE过程示例详解
GraphSAGE简介
GraphSAGE(Graph Sampling and Aggregation)是一种常见的图神经网络模型,主要用于结点级别的表征学习。该模型基于采样和聚合策略,将一个结点及其邻居节点信息融合在一起,得到其表征表示,并通过多轮迭代更新来提高表征的精度。
实现步骤
数据准备
在本次实现中,我们仍然使用Cora数据集作为示例进行测试,由于GraphSage主要聚焦于单一节点特征的更新,因此这里不需要对数据集做特别处理,只需要将数据转化成PyG格式即可。
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx, to_networkx
# 加载cora数据集
dataset = Planetoid(root='./cora', name='Cora')
data = dataset[0]
# 将nx.Graph形式的图转换成PyG需要的格式
graph = to_networkx(data)
data = from_networkx(graph)
# 获取节点数量和特征向量维度
num_nodes = data.num_nodes
num_features = dataset.num_features
num_classes = dataset.num_classes
# 建立需要训练的节点分割数据集
data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[:num_nodes - 1000] = True
data.test_mask[-1000:] = True
data.val_mask[num_nodes - 2000: num_nodes - 1000] = True
实现模型
接下来,我们需要定义GraphSAGE模型。与传统的GCN中只需要一层卷积操作不同,GraphSAGE包含两层卷积和采样(也称“聚合”)操作。
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
def __init__(self, hidden_channels, num_layers):
super(GraphSAGE, self).__init__()
self.convs = nn.ModuleList()
for i in range(num_layers):
in_channels = hidden_channels if i != 0 else num_features
out_channels = num_classes if i == num_layers - 1 else hidden_channels
self.convs.append(SAGEConv(in_channels, out_channels))
def forward(self, x, edge_index):
for _, conv in enumerate(self.convs[:-1]):
x = F.relu(conv(x, edge_index))
# 最后一层不用激活函数
x = self.convs[-1](x, edge_index)
return F.log_softmax(x, dim=-1)
在上述代码中,我们实现了多层GraphSAGE卷积和相应的聚合函数,并使用ReLU和softmax函数来进行特征提取和分类分数的输出。
模型训练
定义好模型之后,就可以开始针对Cora数据集进行模型训练。首先还是需要先指定优化器和损失函数,并设定一些参数用于记录训练过程中的信息,如Epochs、Batch size、学习率等。
# 初始化GraphSage并指定参数
num_layers = 2
hidden_channels = 256
model = GraphSAGE(hidden_channels, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.CrossEntropyLoss()
# 训练过程
for epoch in range(500):
model.train()
optimizer.zero_grad()
out = model(data.x.to(device), data.edge_index.to(device))
loss = loss_func(out[data.train_mask], data.y.to(device)[data.train_mask])
loss.backward()
optimizer.step()
# 在各个测试阶段检测一下准确率
if epoch % 10 == 0:
with torch.no_grad():
_, pred = model(data.x.to(device), data.edge_index.to(device)).max(dim=1)
correct = float(pred[data.test_mask].eq(data.y.to(device)[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print("Epoch {:03D}, Train Loss {:.4f}, Test Acc {:.4f}".fORMat(
epoch, loss.item(), acc))
在上述代码中,我们使用有标记的训练数据拟合GraphSAGE模型,在各个验证阶段测试准确率,并通过梯度下降法优化损失函数。
以上就是PyTorch+PyG实现GraphSAGE过程示例详解的详细内容,更多关于Pytorch PyG实现GraphSAGE的资料请关注其它相关文章!
相关文章