PyTorch 中自定义数据集的读取方法小结

PyTorch入门实战教程

虽然说网上关于 PyTorch 数据集读取的文章和教程多的很,但总觉得哪里不对,尤其是对新手来说,可能需要很长一段时间来钻研和尝试。所以这里我们 PyTorch 中文网为大家总结常用的几种自定义数据集(Custom Dataset)的读取方式(采用 Dataloader)。

本文将涉及以下几个方面:

  • 自定义数据集基础方法
  • 使用 Torchvision Transforms
  • 换一种方法使用 Torchvision Transforms
  • 结合 Pandas 读取 csv 文件
  • 结合 Pandas 使用 __getitem__()
  • 使用 Dataloader 读取自定义数据集

自定义数据集基础方法

首先要创建一个 Dataset 类:

这个代码中:

  • __init__() 一些初始化过程写在这里
  • __len__() 返回所有数据的数量
  • __getitem__() 返回数据和标签,可以这样显示调用:

使用 Torchvision Transforms

Transform 最常见的使用方法是:

换一种方法使用 Torchvision Transforms

有些人不喜欢把 transform 操作写在 Dataset 外面(上面代码里的注释 1),所以还有一种写法:

结合 Pandas 读取 csv 文件

假如说我们想从一个 csv 文件中用 Pandas 读取数据。一个 csv 示例如下:

File Name Label Extra Operation
tr_0.png 5 TRUE
tr_1.png 0 FALSE
tr_1.png 4 FALSE

如果我们需要在自定义数据集里从这个 csv 文件读取文件名,可以这样做:

结合 Pandas 使用 __getitem__()

另一种情况是 csv 文件中保存了我们需要的图像文件的像素值(比如有些 MNIST 教程就是这样的)。我们需要改动一下 __getitem__() 函数。

Label pixel_1 pixel_2
1 50 99
0 21 223
9 44 112
代码如下:

使用 Dataloader 读取自定义数据集

PyTorch 中的 Dataloader 只是调用 __getitem__() 方法并组合成 batch,我们可以这样调用:

需要注意的是使用多卡训练时,PyTorch dataloader 会将每个 batch 平均分配到各个 GPU。所以如果 batch size 过小,可能发挥不了多卡的效果。

PyTorch入门实战教程

Leave a Reply

Your email address will not be published. Required fields are marked *

返回顶部