PyTorch 学习笔记(四):自定义 Dataset 和输入流

PyTorch入门实战教程

什么是Datasets:

在输入流水线中,我们看到准备数据的代码是这么写的data = datasets.CIFAR10(“./data/”, transform=transform, train=True, download=True)。datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

为什么要定义Datasets:

PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。Datasets就是构建这个类的实例的参数之一。

如何自定义Datasets

下面是一个自定义Datasets的框架:

下面看一下官方MNIST的例子(代码被缩减,只留下了重要的部分):

 

文章来源:Keith

PyTorch入门实战教程

Leave a Reply

Your email address will not be published. Required fields are marked *

返回顶部