PyTorch C++ API 系列 5:实现猫狗分类器(二)

PyTorch入门实战教程

上一节我们介绍了数据集,定义了读取数据的函数,定义了网络结构,这一节我们来训练网络。

完整的 PyTorch C++ 系列教程目录如下(或者点击这里查看):

  1. PyTorch C++ API 系列 1:用 VGG-16 识别 MNIST
  2. PyTorch C++ API 系列 2:使用自定义数据集
  3. PyTorch C++ API 系列 3:训练网络
  4. PyTorch C++ API 系列 4:实现猫狗分类器(一)
  5. PyTorch C++ API 系列 5:实现猫狗分类器(二)

之前我们用了一个类似 VGG-16 的网络,多加了一层全连接层在最后面做分类。这里我们可能要对网络输入的图像做一些改变,以便我们能在 CPU 上快速训练:

  1. 输入图像的大小变成 64x64x3
  2. 只使用 2 个卷积层和 2 个最大池化层来训练

这些改变肯定对最后的准确率有所影响,但我们教程的目的只是为了教大家如何使用 PyTorch 的 C++ API。

网络结构

下面是我们上一节里定义的最初的网络,放在这里供大家回顾一下:

可以看到,网络里有 13 个卷积层,5 个最大池化层,4 个全连接层。

新的网络如下:

新的网络只包含 2 个卷积层,2 个最大池化层,3 个全连接层。最为实验目的已经足够了。

训练网络

下面是我们训练这个网络的大概流程:

  1. 把网络设置为 train 模式 net->train()
  2. 对于每个 batch 的数据循环:
    1. 得到数据和对应的标签
    2. 清空 gradients
    3. 前向传播
    4. 计算损失函数
    5. 后向传播
    6. 更新参数
    7. 计算 training accuracy 和 mean square error
  3. 保存训练好的模型。

上述流程的具体代码为:

同样,我们的 test 函数如下:

训练结果

在我们训练了 100 个 epoch 之后,得到下列准确率:

  1. 最好的 training accuracy: 99.82%
  2. 最好的 testing accuracy: 82.43%

我们来看一下结果:

正确的分类示例

狗:

猫:

错误的分类示例

狗:

猫:

好了,这一节就到这里。

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

评论列表(2)

返回顶部