Pytorch+PyG实现EdgeCNN过程示例详解

2023-05-17 05:05:40 示例 过程 详解

1.EdgeCNN简介

EdgeCNN是一种用于图像点云处理的卷积神经网络(Convolutional Neural Network,CNN)模型。与传统的CNN仅能处理图片二维数据不同,EdgeCNN可以对三维点云中每个点周围的局部邻域进行操作,并适用于物体识别、深度估计、自动驾驶等多项任务。

2. 实现步骤

2.1 数据准备

在本实验中,我们使用了一个包含4万个点云的数据集ModelNet10,作为示例。与其它标准图像数据集不同的是,这个数据集中图形的构成量非常大,而且各图之间结构差异很大,因此需要进行大量的预处理工作。

# 导入模型数据集
from torch_geometric.datasets import ModelNet
# 加载ModelNet数据集
dataset = ModelNet(root='./modelnet', name='10')
data = dataset[0]
# 定义超级参数
num_points = 1024
batch_size = 32
train_dataset_size = 8000
# 将数据集分割成训练、验证及测试三个数据集
train_dataset = data[0:train_dataset_size]
val_dataset = data[train_dataset_size: 9000]
test_dataset = data[9000:]
# 定义数据加载批处理器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

通过上述代码,我们先是导入ModelNet数据集并将其分割成训练、验证及测试三个数据集,并创建了数据加载批处理器,以便于在训练过程中对这些数据进行有效的处理。

2.2 实现模型

在定义EdgeCNN模型时,我们需要根据图像点云经常使用的架构定义网络结构。同时,在实现卷积操作时应引入相应的邻域信息,来使得网络能够学习到系统中附近点之间的关系。

from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import EdgeConv, global_max_pool
class EdgeCNN(torch.nn.Module):
    def __init__(self, dataset):
        super(EdgeCNN, self).__init__()
        # 定义基础参数
        self.input_dim = dataset.num_features
        self.output_dim = dataset.num_classes
        self.num_points = num_points
        # 定义模型结构
        self.conv1 = EdgeConv(Seq(Lin(self.input_dim, 32), ReLU()))
        self.conv2 = EdgeConv(Seq(Lin(32, 64), ReLU()))
        self.conv3 = EdgeConv(Seq(Lin(64, 128), ReLU()))
        self.conv4 = EdgeConv(Seq(Lin(128, 256), ReLU()))
        self.fc1 = torch.nn.Linear(256, 1024)
        self.fc2 = torch.nn.Linear(1024, self.output_dim)
    def forward(self, pos, batch):
        # 构造图
        edge_index = radius_graph(pos, r=0.6, batch=batch, loop=False)
        # 第一层CNN模型的卷积 + 池化处理
        x = F.relu(self.conv1(x=pos, edge_index=edge_index))
        x = global_max_pool(x, batch)
        # 第二层CNN模型的卷积 + 池化处理
        edge_index = radius_graph(x, r=0.9, batch=batch, loop=False)
        x = F.relu(self.conv2(x=x, edge_index=edge_index))
        x = global_max_pool(x, batch)
        # 第三层CNN模型的卷积 + 池化处理
        edge_index = radius_graph(x, r=1.2, batch=batch, loop=False)
        x = F.relu(self.conv3(x=x, edge_index=edge_index))
        x = global_max_pool(x, batch)
        # 第四层CNN模型的卷积 + 池化处理
        edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
        x = F.relu(self.conv4(x=x, edge_index=edge_index))
        # 定义全连接网络
        x = global_max_pool(x, batch)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)

在上述代码中,实现了基于EdgeCNN的模型的各个卷积层和全连接层,并使用radius_graph等函数将局部区域问题归约到定义的卷积核检测范围之内,以便更好地对点进行分析和特征提取。最后结合全连接层输出一个维度为类别数的向量,并通过softmax函数来计算损失。

2.3 模型训练

在定义好EdgeCNN网络结构之后,我们还需要指定合适的优化器、损失函数,并控制训练轮数、批大小与学习率等超参数。同时也需要记录大量日志信息,方便后期跟踪及管理。

# 定义训练计划,包括损失函数、优化器及迭代次数等
train_epochs = 50
learning_rate = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(edge_cnn.parameters(), lr=learning_rate)
losses_per_epoch = []
accuracies_per_epoch = []
for epoch in range(train_epochs):
    running_loss = 0.0
    running_corrects = 0.0
    count = 0.0
    for samples in train_loader:
        optimizer.zero_grad()
        pos, batch, label = samples.pos, samples.batch, samples.y.to(torch.long)
        out = edge_cnn(pos, batch)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() / len(train_dataset)
        running_corrects += torch.sum(torch.argmax(out, dim=1) == label).item() / len(train_dataset)
        count += 1
    losses_per_epoch.append(running_loss)
    accuracies_per_epoch.append(running_corrects)
    if (epoch + 1) % 5 == 0:
        print("Train Epoch {}/{} Loss {:.4f} Accuracy {:.4f}".fORMat(
            epoch + 1, train_epochs, running_loss, running_corrects))

在训练过程中,我们遍历每个batch,通过反向传播算法进行优化,并更新loss及accuracy输出。同时,为了方便可视化与记录,需要将训练过程中的loss和accuracy输出到相应的容器中,以便后期进行分析和处理。

以上就是PyTorch+PyG实现EdgeCNN过程示例详解的详细内容,更多关于Pytorch PyG实现EdgeCNN的资料请关注其它相关文章!

相关文章