PyTorch C++ API 系列 4:实现猫狗分类器(一)

PyTorch入门实战教程

在上一节中,我们介绍了怎样使用 PyTorch C++ API 实现自定义数据集读取并且完成网络训练。这一节,我们将实现一个具体的实例:实现猫狗分类器。

完整的 PyTorch C++ 系列教程目录如下(或者点击这里查看):

  1. PyTorch C++ API 系列 1:用 VGG-16 识别 MNIST
  2. PyTorch C++ API 系列 2:使用自定义数据集
  3. PyTorch C++ API 系列 3:训练网络
  4. PyTorch C++ API 系列 4:实现猫狗分类器(一)
  5. PyTorch C++ API 系列 5:实现猫狗分类器(二)

数据集

这次的猫狗数据集来自 Kaggle,链接在此,需要的朋友可以下载到本地实验。

数据集的训练数据包含 25k 张图片,都是猫或狗的照片。例如:

数据读取

我们把猫的图片标记为 0,把狗的图片标记为 1。数据分为两个压缩文件:

  1. train.zip:所有训练集
  2. test.zip:所有测试集

在训练集里,图片的命名方式为 <class>.<number>.jpg,其中:

  • class 是 0 或 1
  • number 代表序列号

我们将所有猫的图片放到 train/cat 文件夹里,所有狗的图片放到 train/dog 文件夹里。这一步操作可以用 Python 的 shutil 模块解决:

然后我们就可以定义读取数据的函数了,我们之前在这里讲过,需要的可以回去参考。这里我们主要的函数有以下几个:

  • load_data_from_folder:读取文件路径和对应的 label,文件路径为 string,label 为 int。
  • process_image:这个函数主要是处理图像,包括读取、调整大小、转换成 tensor 等,然后返回这个 tensor。
  • process_labels:这个函数主要返回 label 的 tensor。

这里是代码:

然后,我们就可以初始化 Dataset 了:

网络结构

这里是我们用到的网络结构:

然后在训练过程中我们初始化这个网络并且传入我们的训练数据:

训练

接下来就是训练网络了,这一部分我们放在下一节讲解。

PyTorch入门实战教程

发表评论

电子邮件地址不会被公开。 必填项已用*标注

返回顶部