fastai 系列教程(三)- CIFAR10 示例

PyTorch入门实战教程

我们这节教程用 CIFAR10 图像分类作为例子,讲解 fastai 中 ConvLearner 的用法。如果对 fastai 还不熟悉,可以参考下面两篇教程,文末有本文代码的 jupyter notebook 供大家自己测试。

  1. fastai 系列教程(一)- 安装
  2. fastai 系列教程(二)- 快速入门 MNIST 示例

准备 CIFAR10 数据

在新版的 fastai 里,我们可以使用 untar_data 来自动下载一些预定义好的数据。首先当然是引入必要的包:

下载数据:

如果打印 CIFAR_PATH 的值,发现是:

到相应的文件夹,可以看到主要是两个文件夹:train 和 test,还有一个 labels.txt 包含了所有的 label 数据:

我们可以打印一张图看看:


初始化模型

我们先定义一组对图像的增广:

然后我们用新的 image_data_from_folder 函数来从数据文件夹读取数据,并做图像增广:

关于image_data_from_folder 的用法,可以参考官方 API 文档

接下来,我们定义 Learner 类(Learner 类用法):

这里的 wrn_22 指的是 WideResNet,我们这里载入预训练模型。这个文件详细定义了 wrn_22 的模型结构。

训练

我们选择用 Learner 类的 fit 方法来训练:

代码很简洁,只需要一行,这也是 fastai 库的魅力所在,可以快速实现很多原型的测试与复用。

Notebook 下载

本节的 Jupyter notebook 可到 Github 下载

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

返回顶部