PyTorch C++ API 系列 2:使用自定义数据集

PyTorch入门实战教程

上一篇文章中,我们讨论了如何使用 PyTorch C++ API 实现 VGG-16 来识别 MNIST 数据集。这篇文章我们讨论一下如何用 C++ API 使用自定义数据集。

完整的 PyTorch C++ 系列教程目录如下(或者点击这里查看):

  1. PyTorch C++ API 系列 1:用 VGG-16 识别 MNIST
  2. PyTorch C++ API 系列 2:使用自定义数据集
  3. PyTorch C++ API 系列 3:训练网络
  4. PyTorch C++ API 系列 4:实现猫狗分类器(一)
  5. PyTorch C++ API 系列 5:实现猫狗分类器(二)

概览

我们先来看一下上一篇教程中我们是怎么读取数据的:

我们来细细讲解。

首先,我们将数据集读入 tensor:

接下来,我们应用一些 transforms

我们 batch_size 为 64:

然后我们就可以将数据传给 data loader 然后由 data loader 传给网络。

工作原理

我们需要了解一下这背后到底是怎么工作的,因此,我们看一下 MNIST 读取的源码文件 torch::data::datasets::MNIST 类,源码地址在这里

对于 MNIST 类的构造器:

我们可以看到这里调用了两个函数:

  1. read_images(root, mode) 读取图像
  2. read_targets(root, mode) 读取图像标签

我们来看看这两个函数具体怎么工作。

read_images(root, mode)

read_targets(root, mode)

还有一些辅助函数:

上面两个函数分别用于获取一个图像及其标签,和返回数据集大小。

自定义数据集的流程

通过上面的源代码查看,我们知道了自定义数据集的大概流程:

  1. 读取数据和标签
  2. 转换成 tensor
  3. 定义 get()size() 两个函数
  4. 初始化类
  5. 将类实例传给 data loader

自定义数据集示例

接下来我们看一个具体示例。下面是整个代码的大体框架:

这里我们使用 OpenCV 来读取图像数据,读取的方法相对比较简单:

注意要转换成 PyTorch 使用的 tensor 顺,即 batch_size, channels, height, width

读取标签:

最终代码:

在下一篇教程中,我们将介绍如何在 CNN 中使用自定义的 data loader。

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

返回顶部